import random

import numpy as np
import torch
from torch.utils.data import Dataset

from augmentation import (
    mask_augmentation,
    random_change_augmentation,
    random_delete_augmentation,
    truncate_augmentation,
)

def tokenize_input(cfg, text):
    """Tokenize a sequence with cfg.tokenizer, padded/truncated to cfg.max_length."""
    inputs = cfg.tokenizer(
        text,
        add_special_tokens=True,
        max_length=cfg.max_length,
        padding="max_length",
        truncation=True,
        return_offsets_mapping=False,
        return_attention_mask=True,
    )
    # Convert each field (input_ids, attention_mask, ...) to a LongTensor.
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs
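
# Usage sketch (assumes cfg.tokenizer is a Hugging Face tokenizer, e.g. one
# loaded with AutoTokenizer.from_pretrained(...), and that cfg.max_length is
# set):
#
#   inputs = tokenize_input(cfg, "ACDEFG")
#   # inputs["input_ids"] and inputs["attention_mask"] are 1-D LongTensors of
#   # length cfg.max_length, padded/truncated as configured above.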

def one_hot_encoding(aa, amino_acids, cfg):
    """One-hot encode a sequence, space-padded/truncated to cfg.max_length."""
    aa = aa[: cfg.max_length].ljust(cfg.max_length, " ")
    one_hot = np.zeros((len(aa), len(amino_acids)))
    for i, a in enumerate(aa):
        # Characters outside the alphabet are left as all-zero rows.
        if a in amino_acids:
            one_hot[i, amino_acids.index(a)] = 1
    return one_hot

def one_hot_encode_input(text, cfg):
    # The 20 standard amino acids plus the space character used for padding.
    amino_acids = (
        "A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
        "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y", " ",
    )
    inputs = one_hot_encoding(text, amino_acids, cfg)
    return torch.tensor(inputs, dtype=torch.float)
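
# Usage sketch (hypothetical cfg; only a `max_length` attribute is assumed):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(max_length=8)
#   x = one_hot_encode_input("ACDE", cfg)
#   # x.shape == (8, 21): one row per position (space-padded to max_length),
#   # one column per symbol in the 21-letter alphabet above.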

class PLTNUMDataset(Dataset):
    """Tokenized protein-sequence dataset for transformer models."""

    def __init__(self, cfg, df, train=True):
        self.df = df
        self.cfg = cfg
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        data = self.df.iloc[idx]
        aas = self._adjust_sequence_length(data[self.cfg.sequence_col])
        if self.train:
            aas = self._apply_augmentation(aas)
        # "__" marks the junction left by the "both" truncation strategy;
        # replace it with the pad token before tokenizing.
        aas = aas.replace("__", "<pad>")
        inputs = tokenize_input(self.cfg, aas)
        if "target" in data:
            return inputs, torch.tensor(data["target"], dtype=torch.float32)
        # Inference rows carry no label; return NaN as a placeholder target.
        return inputs, np.nan

    def _adjust_sequence_length(self, aas):
        # Two positions are reserved for the tokenizer's special tokens;
        # each remaining token covers cfg.token_length residues.
        max_length = (self.cfg.max_length - 2) * self.cfg.token_length
        if len(aas) > max_length:
            if self.cfg.used_sequence == "left":
                return aas[:max_length]
            elif self.cfg.used_sequence == "right":
                return aas[-max_length:]
            elif self.cfg.used_sequence == "both":
                # Keep both ends and mark the junction with "__".
                half_max_len = max_length // 2
                return aas[:half_max_len] + "__" + aas[-half_max_len:]
            elif self.cfg.used_sequence == "internal":
                # Keep a window centered on the middle of the sequence.
                offset = (len(aas) - max_length) // 2
                return aas[offset:offset + max_length]
        return aas

    def _apply_augmentation(self, aas):
        # Each augmentation is gated by its own probability and/or ratio in cfg.
        if self.cfg.random_change_ratio > 0:
            aas = random_change_augmentation(aas, self.cfg)
        if (
            random.random() <= self.cfg.random_delete_prob
        ) and self.cfg.random_delete_ratio > 0:
            aas = random_delete_augmentation(aas, self.cfg)
        if (random.random() <= self.cfg.mask_prob) and self.cfg.mask_ratio > 0:
            aas = mask_augmentation(aas, self.cfg)
        if random.random() <= self.cfg.truncate_augmentation_prob:
            aas = truncate_augmentation(aas, self.cfg)
        return aas

class LSTMDataset(Dataset):
    """One-hot encoded protein-sequence dataset for LSTM models."""

    def __init__(self, cfg, df, train=True):
        self.df = df
        self.cfg = cfg
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        data = self.df.iloc[idx]
        aas = data[self.cfg.sequence_col]
        aas = self._adjust_sequence_length(aas)
        # The characters of "<pad>" fall outside the amino-acid alphabet, so
        # the junction left by the "both" strategy encodes as all-zero rows.
        aas = aas.replace("__", "<pad>")
        inputs = one_hot_encode_input(aas, self.cfg)
        return inputs, torch.tensor(data["target"], dtype=torch.float32)

    def _adjust_sequence_length(self, aas):
        # Same truncation strategies as PLTNUMDataset._adjust_sequence_length.
        max_length = (self.cfg.max_length - 2) * self.cfg.token_length
        if len(aas) > max_length:
            if self.cfg.used_sequence == "left":
                return aas[:max_length]
            elif self.cfg.used_sequence == "right":
                return aas[-max_length:]
            elif self.cfg.used_sequence == "both":
                half_max_len = max_length // 2
                return aas[:half_max_len] + "__" + aas[-half_max_len:]
            elif self.cfg.used_sequence == "internal":
                offset = (len(aas) - max_length) // 2
                return aas[offset:offset + max_length]
        return aas
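
# Minimal end-to-end sketch (hypothetical cfg values; the real configuration
# comes from the training setup). Reads one one-hot encoded sample through
# LSTMDataset.
if __name__ == "__main__":
    from types import SimpleNamespace

    import pandas as pd

    cfg = SimpleNamespace(
        max_length=10,         # positions per encoded sample
        token_length=1,        # residues represented by one token
        used_sequence="both",  # truncation strategy: left/right/both/internal
        sequence_col="aa",     # DataFrame column holding the sequences
    )
    df = pd.DataFrame({"aa": ["ACDEFGHIKLMNPQRSTVWY"], "target": [1.0]})
    inputs, target = LSTMDataset(cfg, df)[0]
    print(inputs.shape, target)  # torch.Size([10, 21]) tensor(1.)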