# PLTNUM / scripts/datasets.py
# Uploaded by sagawa ("Upload 17 files"), commit 4321e7e (verified), 4.24 kB.
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from augmentation import (
mask_augmentation,
random_change_augmentation,
random_delete_augmentation,
truncate_augmentation,
)
def tokenize_input(cfg, text):
    """Tokenize ``text`` with ``cfg.tokenizer`` and tensorize every output field.

    The tokenizer is asked to add special tokens, pad/truncate to
    ``cfg.max_length`` and return an attention mask (presumably a Hugging Face
    tokenizer -- confirm). The mapping it returns is updated in place so each
    value becomes a ``torch.long`` tensor, and that same mapping is returned.
    """
    options = {
        "add_special_tokens": True,
        "max_length": cfg.max_length,
        "padding": "max_length",
        "truncation": True,
        "return_offsets_mapping": False,
        "return_attention_mask": True,
    }
    encoded = cfg.tokenizer(text, **options)
    # Mutate in place (rather than building a new dict) so the tokenizer's
    # own return type is preserved for callers.
    for field in list(encoded.keys()):
        encoded[field] = torch.tensor(encoded[field], dtype=torch.long)
    return encoded
def one_hot_encoding(aa, amino_acids, cfg):
    """One-hot encode the characters of ``aa`` over the ``amino_acids`` alphabet.

    ``aa`` is cut to ``cfg.max_length`` characters and right-padded with
    spaces to that length. Row ``i`` of the returned float64 array has a 1 in
    the column of ``aa[i]``'s first position in ``amino_acids``; characters
    not in the alphabet leave their row all zeros.
    """
    width = cfg.max_length
    padded = aa[:width].ljust(width, " ")
    # setdefault keeps the FIRST column for a repeated symbol, matching
    # tuple.index semantics.
    column_of = {}
    for col, symbol in enumerate(amino_acids):
        column_of.setdefault(symbol, col)
    encoded = np.zeros((len(padded), len(amino_acids)))
    for row, symbol in enumerate(padded):
        col = column_of.get(symbol)
        if col is not None:
            encoded[row, col] = 1
    return encoded
def one_hot_encode_input(text, cfg):
    """One-hot encode ``text`` over the 20 standard amino-acid letters plus space.

    Delegates to ``one_hot_encoding`` (which fits ``text`` to
    ``cfg.max_length`` rows) and returns the result as a float32 tensor.
    """
    # 20 uppercase amino-acid letters followed by the space used as padding.
    alphabet = tuple("ACDEFGHIKLMNPQRSTVWY") + (" ",)
    matrix = one_hot_encoding(text, alphabet, cfg)
    return torch.tensor(matrix, dtype=torch.float)
class PLTNUMDataset(Dataset):
    """Map-style dataset that tokenizes protein sequences for the PLTNUM model.

    Each item is built from one row of ``df``. The sequence in
    ``cfg.sequence_col`` is length-adjusted according to
    ``cfg.used_sequence``. It is then augmented (only when ``train`` is true)
    and tokenized via ``tokenize_input``.
    """

    def __init__(self, cfg, df, train=True):
        self.df = df
        self.cfg = cfg
        # Augmentation is applied in __getitem__ only when this flag is set.
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(inputs, target)`` for row ``idx``.

        ``inputs`` is the mapping returned by ``tokenize_input`` (each value a
        ``torch.long`` tensor). ``target`` is a float32 tensor built from the
        row's ``"target"`` field, or ``np.nan`` when the row has no such field.
        """
        # assumes df is a pandas DataFrame, so iloc yields a Series and
        # `"target" in data` tests its index labels -- confirm with callers.
        data = self.df.iloc[idx]
        aas = self._adjust_sequence_length(data[self.cfg.sequence_col])
        if self.train:
            aas = self._apply_augmentation(aas)
        # "__" is the junction marker inserted by _adjust_sequence_length in
        # "both" mode; presumably "<pad>" is the tokenizer's pad token -- confirm.
        aas = aas.replace("__", "<pad>")
        inputs = tokenize_input(self.cfg, aas)
        if "target" in data:
            return inputs, torch.tensor(data["target"], dtype=torch.float32)
        return inputs, np.nan

    def _adjust_sequence_length(self, aas):
        """Shorten ``aas`` to the residue budget derived from the config.

        The budget is ``(cfg.max_length - 2) * cfg.token_length`` characters,
        clamped at 0. Sequences within budget are returned unchanged. Longer
        ones are cut according to ``cfg.used_sequence``:

        * ``"left"``: keep the first ``budget`` characters.
        * ``"right"``: keep the last ``budget`` characters.
        * ``"both"``: keep ``budget // 2`` characters from each end, joined by
          the ``"__"`` marker.
        * ``"internal"``: keep a centered window of ``budget`` characters.

        Any other ``used_sequence`` value leaves the sequence unchanged.
        """
        # Two token positions are reserved, presumably for the special tokens
        # added by tokenize_input (add_special_tokens=True) -- confirm.
        max_length = max((self.cfg.max_length - 2) * self.cfg.token_length, 0)
        if len(aas) <= max_length:
            return aas
        mode = self.cfg.used_sequence
        if mode == "left":
            return aas[:max_length]
        if mode == "right":
            # Explicit start index: aas[-0:] would return the WHOLE string
            # when the budget is 0.
            return aas[len(aas) - max_length:]
        if mode == "both":
            half_max_len = max_length // 2
            # Same -0 pitfall as above when half_max_len == 0.
            return aas[:half_max_len] + "__" + aas[len(aas) - half_max_len:]
        if mode == "internal":
            offset = (len(aas) - max_length) // 2
            return aas[offset:offset + max_length]
        return aas

    def _apply_augmentation(self, aas):
        """Apply the configured training-time augmentations to ``aas``.

        The steps run in a fixed order:

        1. Random change, whenever ``cfg.random_change_ratio > 0``.
        2. Random deletion, gated by its own ``random.random()`` draw.
        3. Masking, gated by its own ``random.random()`` draw.
        4. Truncation, gated by its own ``random.random()`` draw.

        The deletion and masking steps additionally require a positive ratio.
        """
        if self.cfg.random_change_ratio > 0:
            aas = random_change_augmentation(aas, self.cfg)
        # random.random() is drawn before the ratio check, so the RNG stream
        # advances even when the ratio disables the step.
        if (
            random.random() <= self.cfg.random_delete_prob
        ) and self.cfg.random_delete_ratio > 0:
            aas = random_delete_augmentation(aas, self.cfg)
        if (random.random() <= self.cfg.mask_prob) and self.cfg.mask_ratio > 0:
            aas = mask_augmentation(aas, self.cfg)
        if random.random() <= self.cfg.truncate_augmentation_prob:
            aas = truncate_augmentation(aas, self.cfg)
        return aas
class LSTMDataset(Dataset):
    """Map-style dataset that one-hot encodes protein sequences for an LSTM model.

    Each item is built from one row of ``df``. The sequence in
    ``cfg.sequence_col`` is length-adjusted according to
    ``cfg.used_sequence`` and then one-hot encoded via
    ``one_hot_encode_input``.
    """

    def __init__(self, cfg, df, train=True):
        self.df = df
        self.cfg = cfg
        # Stored for constructor parity with PLTNUMDataset; nothing in this
        # class reads it (no augmentation is applied here).
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(inputs, target)`` for row ``idx``.

        ``inputs`` is the float32 one-hot tensor from
        ``one_hot_encode_input``. ``target`` is a float32 tensor built from
        the row's ``"target"`` field, or ``np.nan`` when the row has no such
        field (previously this raised ``KeyError`` on unlabeled rows).
        """
        # assumes df is a pandas DataFrame, so iloc yields a Series and
        # `"target" in data` tests its index labels -- confirm with callers.
        data = self.df.iloc[idx]
        aas = self._adjust_sequence_length(data[self.cfg.sequence_col])
        # NOTE(review): "<pad>" is not in the one-hot alphabet, so each of its
        # five characters encodes as an all-zero row (vs. two rows for "__").
        # Kept as-is so the produced encodings stay identical.
        aas = aas.replace("__", "<pad>")
        inputs = one_hot_encode_input(aas, self.cfg)
        if "target" in data:
            return inputs, torch.tensor(data["target"], dtype=torch.float32)
        return inputs, np.nan

    def _adjust_sequence_length(self, aas):
        """Shorten ``aas`` to the residue budget derived from the config.

        The budget is ``(cfg.max_length - 2) * cfg.token_length`` characters,
        clamped at 0. Sequences within budget are returned unchanged. Longer
        ones are cut according to ``cfg.used_sequence``:

        * ``"left"``: keep the first ``budget`` characters.
        * ``"right"``: keep the last ``budget`` characters.
        * ``"both"``: keep ``budget // 2`` characters from each end, joined by
          the ``"__"`` marker.
        * ``"internal"``: keep a centered window of ``budget`` characters.

        Any other ``used_sequence`` value leaves the sequence unchanged.
        """
        max_length = max((self.cfg.max_length - 2) * self.cfg.token_length, 0)
        if len(aas) <= max_length:
            return aas
        mode = self.cfg.used_sequence
        if mode == "left":
            return aas[:max_length]
        if mode == "right":
            # Explicit start index: aas[-0:] would return the WHOLE string
            # when the budget is 0.
            return aas[len(aas) - max_length:]
        if mode == "both":
            half_max_len = max_length // 2
            # Same -0 pitfall as above when half_max_len == 0.
            return aas[:half_max_len] + "__" + aas[len(aas) - half_max_len:]
        if mode == "internal":
            offset = (len(aas) - max_length) // 2
            return aas[offset:offset + max_length]
        return aas