import os

import numpy as np
import torch
from torch.utils.data import Dataset

class IUPACDataset(Dataset):
    """Memory-mapped dataset of IUPAC names (plus an optional target column)
    stored one record per line in a "|"-separated text file."""

    def __init__(self, dataset_dir="./", dataset_filename="iupacs_logp.txt",
                 tokenizer=None, max_length=None, target_col=None,
                 dataset_size=None, iupac_name_col="iupac"):
        self.dataset_dir = dataset_dir
        self.tokenizer = tokenizer
        self.target_col = target_col
        self.max_length = max_length
        self.dataset_size = dataset_size
        self.dataset_filename = dataset_filename

        self.dataset_fn = os.path.join(self.dataset_dir, self.dataset_filename)

        # Memory-map the data file and record the byte offset at which each
        # line starts, so __getitem__ can seek straight to a record without
        # reading the whole file into memory.
        line_offsets = []
        self.data_mm = np.memmap(self.dataset_fn, dtype=np.uint8, mode="r")

        # scan for newlines in ~1 GB chunks to keep peak memory bounded
        chunksize = int(1e9)
        for i in range(0, len(self.data_mm), chunksize):
            chunk = self.data_mm[i:i + chunksize]
            # 0x0a is "\n"; each line starts one byte past a newline
            newlines = np.nonzero(chunk == 0x0a)[0]
            line_offsets.append(i + newlines + 1)
            # stop early once enough line starts have been collected
            # (dataset_size records need dataset_size + 1 boundaries)
            if (self.dataset_size is not None
                    and sum(len(o) for o in line_offsets) > self.dataset_size):
                break

        self.line_offsets = np.hstack(line_offsets)

        if (self.dataset_size is not None
                and self.dataset_size > self.line_offsets.shape[0]):
            msg = "specified dataset_size {}, but the dataset only has {} items"
            raise ValueError(msg.format(self.dataset_size,
                                        self.line_offsets.shape[0]))

        # the first line of the file is a "|"-separated header
        header_line = bytes(self.data_mm[0:self.line_offsets[0]])
        headers = header_line.decode("utf8").strip().split("|")

        try:
            self.name_col_id = headers.index(iupac_name_col)
        except ValueError as e:
            raise RuntimeError("Expecting a column called '{}' that contains "
                               "IUPAC names".format(iupac_name_col)) from e

        self.target_col_id = None
        if self.target_col is not None:
            try:
                self.target_col_id = headers.index(self.target_col)
            except ValueError as e:
                raise RuntimeError("User supplied target col '{}', but that "
                                   "column is not present in the data "
                                   "file".format(target_col)) from e

    def __getitem__(self, idx):
        if self.dataset_size is not None and idx >= self.dataset_size:
            msg = "provided index {} is out of range for dataset size {}"
            raise IndexError(msg.format(idx, self.dataset_size))

        # line_offsets[idx] is the start of record idx (the header line ends
        # at line_offsets[0]), so each record spans two adjacent offsets
        start = self.line_offsets[idx]
        end = self.line_offsets[idx + 1]
        line = bytes(self.data_mm[start:end])
        line = line.decode("utf8").strip().split("|")
        name = line[self.name_col_id]

        target = None
        if self.target_col_id is not None:
            target = line[self.target_col_id]
            if self.target_col == "Log P" and len(target) == 0:
                # impute a fixed default when the Log P field is empty
                target = 3.16
            else:
                target = float(target)

        tokenized = self.tokenizer(name)
        input_ids = torch.tensor(tokenized["input_ids"])

        # prepend the tokenizer's unknown token to every sequence
        iupac_unk = torch.tensor([self.tokenizer.unk_token_id])
        input_ids = torch.cat([iupac_unk, input_ids])

        return_dict = {}
        return_dict["input_ids"] = input_ids
        return_dict["labels"] = input_ids
        # return the regression target when a target column was requested
        if target is not None:
            return_dict["target"] = torch.tensor(target)

        if self.max_length is not None:
            return_dict["input_ids"] = return_dict["input_ids"][:self.max_length]
            return_dict["labels"] = return_dict["labels"][:self.max_length]

        return return_dict

    def __len__(self):
        if self.dataset_size is None:
            # line_offsets also includes the end-of-header boundary, hence -1
            return len(self.line_offsets) - 1
        else:
            return self.dataset_size
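

if __name__ == "__main__":
    # Minimal usage sketch, not a definitive example: it assumes a
    # "|"-separated data file with an "iupac" column (and optionally a
    # "Log P" column), and any Hugging Face-style tokenizer whose __call__
    # returns "input_ids". "t5-small" is only a placeholder; substitute the
    # project's own IUPAC tokenizer as appropriate.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    dataset = IUPACDataset(dataset_dir="./",
                           dataset_filename="iupacs_logp.txt",
                           tokenizer=tokenizer,
                           max_length=128,
                           target_col="Log P")
    print(len(dataset))
    print(dataset[0])  # dict with input_ids, labels (and target, if set)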