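"""Dataset and dataloader utilities for language-to-molecule (Lang2mol) training,
evaluation, and submission. Molecules are represented as SELFIES strings and
captions are embedded with a frozen BioT5 text encoder."""
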
import random

import selfies as sf
import torch
from datasets import load_dataset
from rdkit import Chem
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from transformers import T5EncoderModel


def get_dataloader(dataset, batchsize, rank, world_size):
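    """Wrap `dataset` in a DistributedSampler-backed DataLoader and return an
    infinite batch iterator that re-shuffles every epoch via `set_epoch`."""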
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True
    )

    def collate(batch):
        # Every per-sample tensor already has a leading batch dimension of 1,
        # so concatenating along dim 0 stacks the samples into a batch.
        selfies_ids = [i["selfies_ids"] for i in batch]
        caption_state = [i["caption_state"] for i in batch]
        caption_mask = [i["caption_mask"] for i in batch]
        corrupted_selfies_ids = [i["corrupted_selfies_ids"] for i in batch]
        return (
            torch.concat(selfies_ids, dim=0),
            torch.concat(caption_state, dim=0),
            torch.concat(caption_mask, dim=0),
            torch.concat(corrupted_selfies_ids, dim=0),
        )

    dataloader = DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=False,  # shuffling is delegated to the DistributedSampler
        collate_fn=collate,
        sampler=sampler,
    )

    def cycle():
        # Iterate forever, bumping the sampler epoch so every pass over the
        # data uses a different shuffle order.
        epoch = 0
        while True:
            dataloader.sampler.set_epoch(epoch)
            for batch in dataloader:
                yield batch
            epoch += 1

    return cycle()


class Lang2molDataset_train(Dataset):
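    """Training split: items carry tokenized SELFIES ids (plus an identical
    "corrupted" copy) and caption hidden states from a frozen BioT5 encoder."""
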
    def __init__(
        self,
        dir,
        tokenizer,
        split,
        dataset_name,
        pre=None,
        prob=0,
        load_state=True,
        corrupt_prob=0.4,
        token_max_length=256,
    ):
        super().__init__()
        self.dir = dir
        self.tokenizer = tokenizer
        self.split = split
        self.pre = pre
        self.prob = prob
        self.corrupt_prob = corrupt_prob
        self.token_max_length = token_max_length
        self.dataset_name = dataset_name
        self.ori_data = self.create_data()
        self.load_state = load_state
        # Frozen caption encoder that produces the text hidden states.
        self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
        self.model.to("cuda")
        self.model.eval()

    def create_data(self):
        # Newer versions of `datasets` expect `token=`, older ones `use_auth_token=`.
        try:
            dataset = load_dataset(
                self.dataset_name,
                token=True,
                split=self.split,
            ).sort("id")
        except Exception:
            dataset = load_dataset(
                self.dataset_name,
                use_auth_token=True,
                split=self.split,
            ).sort("id")
        return [
            (int(sample_id), sample_selfies, sample_caption, sample_smiles)
            for (sample_id, sample_selfies, sample_caption, sample_smiles) in zip(
                dataset["id"],
                dataset["selfies"],
                dataset["caption"],
                dataset["smiles"],
            )
        ]

    def __len__(self):
        return len(self.ori_data)

    def permute(self, selfies):
        # With probability `prob`, randomly re-order the molecule's atoms.
        if random.random() < self.prob:
            return changeorder(selfies, shuffle=True)
        return selfies

    def __getitem__(self, idx):
        data = self.ori_data[idx]
        sample = {
            "id": data[0],
            "selfies": self.permute(data[1]),
            "caption": data[2],
            "smiles": data[3],
        }

        # Molecules: tokenize the (possibly re-ordered) SELFIES string.
        output_molecule = self.tokenizer(
            sample["selfies"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        sample["selfies_ids"] = output_molecule["input_ids"]
        # No corruption is applied here; the "corrupted" ids are the same tensor.
        sample["corrupted_selfies_ids"] = sample["selfies_ids"]

        # Captions: tokenize and encode with the frozen BioT5 encoder.
        output_caption = self.tokenizer(
            sample["caption"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        with torch.no_grad():
            sample["caption_state"] = self.model(
                input_ids=output_caption["input_ids"].to("cuda"),
                attention_mask=output_caption["attention_mask"].to("cuda"),
            ).last_hidden_state
        sample["caption_mask"] = output_caption["attention_mask"]
        return sample


class Lang2molDataset_eval(Dataset):
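    """Evaluation split: items keep the raw SELFIES/SMILES strings and add
    caption hidden states from a frozen BioT5 encoder."""
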
    def __init__(
        self,
        dir,
        tokenizer,
        split,
        dataset_name,
        pre=None,
        prob=0,
        load_state=True,
        corrupt_prob=0.4,
        token_max_length=512,
    ):
        super().__init__()
        self.dir = dir
        self.tokenizer = tokenizer
        self.split = split
        self.pre = pre
        self.prob = prob
        self.corrupt_prob = corrupt_prob
        self.token_max_length = token_max_length
        self.dataset_name = dataset_name
        self.ori_data = self.create_data()
        self.load_state = load_state
        # Frozen caption encoder that produces the text hidden states.
        self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
        self.model.to("cuda")
        self.model.eval()

    def create_data(self):
        # Newer versions of `datasets` expect `token=`, older ones `use_auth_token=`.
        try:
            dataset = load_dataset(
                self.dataset_name,
                token=True,
                split=self.split,
            ).sort("id")
        except Exception:
            dataset = load_dataset(
                self.dataset_name,
                use_auth_token=True,
                split=self.split,
            ).sort("id")
        return [
            (int(sample_id), sample_selfies, sample_caption, sample_smiles)
            for (sample_id, sample_selfies, sample_caption, sample_smiles) in zip(
                dataset["id"],
                dataset["selfies"],
                dataset["caption"],
                dataset["smiles"],
            )
        ]

    def __len__(self):
        return len(self.ori_data)

    def permute(self, selfies):
        # With probability `prob`, randomly re-order the molecule's atoms.
        if random.random() < self.prob:
            return changeorder(selfies, shuffle=True)
        return selfies

    def __getitem__(self, idx):
        data = self.ori_data[idx]
        sample = {
            "id": data[0],
            "selfies": self.permute(data[1]),
            "caption": data[2],
            "smiles": data[3],
        }

        # Captions: tokenize and encode with the frozen BioT5 encoder.
        output_caption = self.tokenizer(
            sample["caption"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        with torch.no_grad():
            sample["caption_state"] = self.model(
                input_ids=output_caption["input_ids"].to("cuda"),
                attention_mask=output_caption["attention_mask"].to("cuda"),
            ).last_hidden_state
        sample["caption_mask"] = output_caption["attention_mask"]
        return sample


class Lang2molDataset_submission(Dataset):
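    """Submission split: the source dataset only provides captions, so each item
    is a caption plus its BioT5 encoder states and attention mask."""
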
    def __init__(
        self,
        dir,
        tokenizer,
        split,
        dataset_name,
        pre=None,
        prob=0,
        load_state=True,
        corrupt_prob=0.4,
        token_max_length=256,
    ):
        super().__init__()
        self.dir = dir
        self.tokenizer = tokenizer
        self.split = split
        self.pre = pre
        self.prob = prob
        self.corrupt_prob = corrupt_prob
        self.token_max_length = token_max_length
        self.dataset_name = dataset_name
        self.ori_data = self.create_data()
        self.load_state = load_state
        # Frozen caption encoder that produces the text hidden states.
        self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
        self.model.to("cuda")
        self.model.eval()

    def create_data(self):
        # Newer versions of `datasets` expect `token=`, older ones `use_auth_token=`.
        try:
            dataset = load_dataset(
                self.dataset_name,
                token=True,
                split=self.split,
            )
        except Exception:
            dataset = load_dataset(
                self.dataset_name,
                use_auth_token=True,
                split=self.split,
            )
        return list(dataset["caption"])

    def __len__(self):
        return len(self.ori_data)

    def permute(self, selfies):
        # Not used by __getitem__ in this caption-only split.
        if random.random() < self.prob:
            return changeorder(selfies, shuffle=True)
        return selfies

    def __getitem__(self, idx):
        sample = {"caption": self.ori_data[idx]}

        # Captions: tokenize and encode with the frozen BioT5 encoder.
        output_caption = self.tokenizer(
            sample["caption"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        with torch.no_grad():
            sample["caption_state"] = self.model(
                input_ids=output_caption["input_ids"].to("cuda"),
                attention_mask=output_caption["attention_mask"].to("cuda"),
            ).last_hidden_state
        sample["caption_mask"] = output_caption["attention_mask"]
        return sample


def changeorder(selfies, shuffle):
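    """Re-derive a SELFIES string after (optionally shuffled) atom renumbering;
    return the input unchanged if RDKit cannot parse the molecule."""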
    smiles = sf.decoder(selfies)  # SELFIES -> SMILES
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return selfies
    Chem.Kekulize(mol)
    atom_indices = [atom.GetIdx() for atom in mol.GetAtoms()]
    if shuffle:
        random.shuffle(atom_indices)
    reordered_mol = Chem.RenumberAtoms(mol, atom_indices)
    new_smiles = Chem.MolToSmiles(reordered_mol, kekuleSmiles=True)
    new_selfies = sf.encoder(new_smiles)  # SMILES -> SELFIES
    return new_selfies
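

# Minimal usage sketch (assumptions, not part of the original module: a CUDA
# device, world_size=1 in a single process, the BioT5 tokenizer, and a Hugging
# Face dataset "<your-dataset-id>" exposing "id", "selfies", "caption", and
# "smiles" columns):
#
#     from transformers import T5Tokenizer
#     tokenizer = T5Tokenizer.from_pretrained("QizhiPei/biot5-base-text2mol")
#     train_set = Lang2molDataset_train(
#         dir=".",
#         tokenizer=tokenizer,
#         split="train",
#         dataset_name="<your-dataset-id>",
#     )
#     batches = get_dataloader(train_set, batchsize=4, rank=0, world_size=1)
#     selfies_ids, caption_state, caption_mask, corrupted_ids = next(batches)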