# Lang2mol-Diff/src/scripts/mydatasets.py
import torch
import random
import selfies as sf
from rdkit import Chem
from datasets import load_dataset
from transformers import T5EncoderModel
from torch.utils.data import DistributedSampler, DataLoader, Dataset
def get_dataloader(dataset, batchsize, rank, world_size):
    """Return an infinite batch iterator over `dataset`, sharded across DDP ranks."""
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True
    )
    def collate(batch):
        # Each sample field carries a leading batch dim of 1, so concatenating
        # on dim 0 stacks the samples into a single batch tensor.
selfies_ids = [i["selfies_ids"] for i in batch]
caption_state = [i["caption_state"] for i in batch]
caption_mask = [i["caption_mask"] for i in batch]
corrupted_selfies_ids = [i["corrupted_selfies_ids"] for i in batch]
return (
torch.concat(selfies_ids, dim=0),
torch.concat(caption_state, dim=0),
torch.concat(caption_mask, dim=0),
torch.concat(corrupted_selfies_ids, dim=0),
)
dataloader = DataLoader(
dataset,
batch_size=batchsize,
        shuffle=False,  # shuffling is delegated to the DistributedSampler
collate_fn=collate,
sampler=sampler,
)
    def cycle():
        epoch = 0
        while True:
            # Reseed the sampler each pass so every epoch sees a new shuffle order.
            dataloader.sampler.set_epoch(epoch)
            for batch in dataloader:
                yield batch
            epoch += 1

    return cycle()
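
# Usage sketch (illustrative, not part of the training code): driving the
# infinite loader inside a DDP process. `MyTokenizer` and the dataset name
# below are placeholders, not names from this repo.
#
#   dataset = Lang2molDataset_train(
#       dir="data",
#       tokenizer=MyTokenizer(),
#       split="train",
#       dataset_name="some/hf-dataset",
#   )
#   loader = get_dataloader(dataset, batchsize=16, rank=rank, world_size=world_size)
#   selfies_ids, caption_state, caption_mask, corrupted_ids = next(loader)
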
class Lang2molDataset_train(Dataset):
def __init__(
self,
dir,
tokenizer,
split,
dataset_name,
pre=None,
prob=0,
load_state=True,
corrupt_prob=0.4,
token_max_length=256,
):
super().__init__()
self.dir = dir
self.tokenizer = tokenizer
self.split = split
self.pre = pre
self.prob = prob
self.corrupt_prob = corrupt_prob
self.token_max_length = token_max_length
self.dataset_name = dataset_name
self.ori_data = self.create_data()
self.load_state = load_state
        # Frozen BioT5 text encoder used to embed captions on the fly.
        self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
        self.model.to("cuda")
        self.model.eval()
def create_data(self):
try:
dataset = load_dataset(
self.dataset_name,
token=True,
split=self.split,
).sort("id")
        except (TypeError, ValueError):
            # Older `datasets` releases reject the `token` kwarg and expect
            # `use_auth_token` instead.
dataset = load_dataset(
self.dataset_name,
use_auth_token=True,
split=self.split,
).sort("id")
return [
(int(sample_id), sample_selfies, sample_caption, sample_smiles)
for (sample_id, sample_selfies, sample_caption, sample_smiles) in zip(
dataset["id"],
dataset["selfies"],
dataset["caption"],
dataset["smiles"],
)
]
def __len__(self):
return len(self.ori_data)
def permute(self, selfies):
if random.random() < self.prob:
return changeorder(selfies, shuffle=True)
else:
return selfies
def __getitem__(self, idx):
data = self.ori_data[idx]
sample = {
"id": data[0],
"selfies": self.permute(data[1]),
"caption": data[2],
"smiles": data[3],
}
# Molecules
output_molecule = self.tokenizer(
sample["selfies"],
max_length=self.token_max_length,
truncation=True,
padding="max_length",
add_special_tokens=True,
return_tensors="pt",
return_attention_mask=True,
)
sample["selfies_ids"] = output_molecule["input_ids"]
sample["corrupted_selfies_ids"] = sample["selfies_ids"]
# Captions
output_caption = self.tokenizer(
sample["caption"],
max_length=self.token_max_length,
truncation=True,
padding="max_length",
add_special_tokens=True,
return_tensors="pt",
return_attention_mask=True,
)
sample["caption_state"] = self.model(
input_ids=output_caption["input_ids"].to("cuda"),
attention_mask=output_caption["attention_mask"].to("cuda"),
).last_hidden_state
sample["caption_mask"] = output_caption["attention_mask"]
return sample
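
# Shape sketch for one training sample (assuming a T5-style tokenizer that
# pads to `token_max_length`); keys match __getitem__ above:
#
#   sample = dataset[0]
#   sample["selfies_ids"]       # LongTensor, shape (1, token_max_length)
#   sample["caption_state"]     # FloatTensor, shape (1, token_max_length, hidden_size)
#   sample["caption_mask"]      # LongTensor, shape (1, token_max_length)
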
class Lang2molDataset_eval(Dataset):
def __init__(
self,
dir,
tokenizer,
split,
dataset_name,
pre=None,
prob=0,
load_state=True,
corrupt_prob=0.4,
token_max_length=512,
):
super().__init__()
self.dir = dir
self.tokenizer = tokenizer
self.split = split
self.pre = pre
self.prob = prob
self.corrupt_prob = corrupt_prob
self.token_max_length = token_max_length
self.dataset_name = dataset_name
self.ori_data = self.create_data()
self.load_state = load_state
self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
self.model.to("cuda")
self.model.eval()
def create_data(self):
try:
dataset = load_dataset(
self.dataset_name,
token=True,
split=self.split,
).sort("id")
        except (TypeError, ValueError):
            # Older `datasets` releases reject the `token` kwarg and expect
            # `use_auth_token` instead.
dataset = load_dataset(
self.dataset_name,
use_auth_token=True,
split=self.split,
).sort("id")
return [
(int(sample_id), sample_selfies, sample_caption, sample_smiles)
for (sample_id, sample_selfies, sample_caption, sample_smiles) in zip(
dataset["id"],
dataset["selfies"],
dataset["caption"],
dataset["smiles"],
)
]
def __len__(self):
return len(self.ori_data)
def permute(self, selfies):
if random.random() < self.prob:
return changeorder(selfies, shuffle=True)
else:
return selfies
def __getitem__(self, idx):
data = self.ori_data[idx]
sample = {
"id": data[0],
"selfies": self.permute(data[1]),
"caption": data[2],
"smiles": data[3],
}
output_caption = self.tokenizer(
sample["caption"],
max_length=self.token_max_length,
truncation=True,
padding="max_length",
add_special_tokens=True,
return_tensors="pt",
return_attention_mask=True,
)
sample["caption_state"] = self.model(
input_ids=output_caption["input_ids"].to("cuda"),
attention_mask=output_caption["attention_mask"].to("cuda"),
).last_hidden_state
sample["caption_mask"] = output_caption["attention_mask"]
return sample
class Lang2molDataset_submission(Dataset):
def __init__(
self,
dir,
tokenizer,
split,
dataset_name,
pre=None,
prob=0,
load_state=True,
corrupt_prob=0.4,
token_max_length=256,
):
super().__init__()
self.dir = dir
self.tokenizer = tokenizer
self.split = split
self.pre = pre
self.prob = prob
self.corrupt_prob = corrupt_prob
self.token_max_length = token_max_length
self.dataset_name = dataset_name
self.ori_data = self.create_data()
self.load_state = load_state
self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
self.model.to("cuda")
self.model.eval()
def create_data(self):
try:
dataset = load_dataset(
self.dataset_name,
token=True,
split=self.split,
)
        except (TypeError, ValueError):
            # Older `datasets` releases reject the `token` kwarg and expect
            # `use_auth_token` instead.
dataset = load_dataset(
self.dataset_name,
use_auth_token=True,
split=self.split,
)
        return list(dataset["caption"])
def __len__(self):
return len(self.ori_data)
def permute(self, selfies):
if random.random() < self.prob:
return changeorder(selfies, shuffle=True)
else:
return selfies
def __getitem__(self, idx):
sample = {"caption": self.ori_data[idx]}
# Captions
output_caption = self.tokenizer(
sample["caption"],
max_length=self.token_max_length,
truncation=True,
padding="max_length",
add_special_tokens=True,
return_tensors="pt",
return_attention_mask=True,
)
sample["caption_state"] = self.model(
input_ids=output_caption["input_ids"].to("cuda"),
attention_mask=output_caption["attention_mask"].to("cuda"),
).last_hidden_state
sample["caption_mask"] = output_caption["attention_mask"]
return sample
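
# Inference-time sketch (illustrative): the submission dataset embeds captions
# only, since no ground-truth molecule is available. `MyTokenizer` is a
# placeholder.
#
#   sub_ds = Lang2molDataset_submission(
#       dir="data",
#       tokenizer=MyTokenizer(),
#       split="test",
#       dataset_name="some/hf-dataset",
#   )
#   sample = sub_ds[0]  # keys: "caption", "caption_state", "caption_mask"
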
def changeorder(selfies, shuffle):
    """Renumber (optionally shuffle) a molecule's atoms; return an equivalent SELFIES."""
    # SELFIES -> SMILES -> RDKit mol.
    smiles = sf.decoder(selfies)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Fall back to the input when RDKit cannot parse the molecule.
        return selfies
    Chem.Kekulize(mol, clearAromaticFlags=True)
    atom_indices = [atom.GetIdx() for atom in mol.GetAtoms()]
    if shuffle:
        random.shuffle(atom_indices)
    reordered_mol = Chem.RenumberAtoms(mol, atom_indices)
    new_smiles = Chem.MolToSmiles(reordered_mol, kekuleSmiles=True)
    # SMILES -> SELFIES.
    return sf.encoder(new_smiles)
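
# Quick sanity check (illustrative) for the augmentation round trip, assuming
# `selfies` and RDKit are installed. Ethanol: "CCO" <-> "[C][C][O]".
#
#   original = sf.encoder("CCO")                    # "[C][C][O]"
#   shuffled = changeorder(original, shuffle=True)  # equivalent molecule, new atom order
#   assert Chem.MolFromSmiles(sf.decoder(shuffled)) is not None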