import random

import numpy as np
import torch
from torch.utils.data import Dataset

from augmentation import (
    mask_augmentation,
    random_change_augmentation,
    random_delete_augmentation,
    truncate_augmentation,
)


def tokenize_input(cfg, text):
    """Tokenize a sequence with cfg.tokenizer and return tensors of token ids."""
    inputs = cfg.tokenizer(
        text,
        add_special_tokens=True,
        max_length=cfg.max_length,
        padding="max_length",
        truncation=True,
        return_offsets_mapping=False,
        return_attention_mask=True,
    )
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs


def one_hot_encoding(aa, amino_acids, cfg):
    """One-hot encode an amino-acid string, padded/truncated to cfg.max_length."""
    aa = aa[: cfg.max_length].ljust(cfg.max_length, " ")
    one_hot = np.zeros((len(aa), len(amino_acids)))
    for i, a in enumerate(aa):
        if a in amino_acids:
            one_hot[i, amino_acids.index(a)] = 1
    return one_hot


def one_hot_encode_input(text, cfg):
    """One-hot encode a sequence over the 20 standard amino acids plus padding (" ")."""
    amino_acids = (
        "A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
        "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y", " ",
    )
    inputs = one_hot_encoding(text, amino_acids, cfg)
    return torch.tensor(inputs, dtype=torch.float)


class PLTNUMDataset(Dataset):
    """Dataset that tokenizes protein sequences and optionally applies augmentation."""

    def __init__(self, cfg, df, train=True):
        self.df = df
        self.cfg = cfg
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        data = self.df.iloc[idx]
        aas = self._adjust_sequence_length(data[self.cfg.sequence_col])

        if self.train:
            aas = self._apply_augmentation(aas)

        # Remove the "__" separator inserted by the "both" truncation strategy.
        aas = aas.replace("__", "")
        inputs = tokenize_input(self.cfg, aas)

        if "target" in data:
            return inputs, torch.tensor(data["target"], dtype=torch.float32)
        return inputs, np.nan

    def _adjust_sequence_length(self, aas):
        # Two positions are reserved for special tokens; each token covers
        # cfg.token_length residues.
        max_length = (self.cfg.max_length - 2) * self.cfg.token_length
        if len(aas) > max_length:
            if self.cfg.used_sequence == "left":
                return aas[:max_length]
            elif self.cfg.used_sequence == "right":
                return aas[-max_length:]
            elif self.cfg.used_sequence == "both":
                half_max_len = max_length // 2
                return aas[:half_max_len] + "__" + aas[-half_max_len:]
            elif self.cfg.used_sequence == "internal":
                offset = (len(aas) - max_length) // 2
                return aas[offset:offset + max_length]
        return aas

    def _apply_augmentation(self, aas):
        if self.cfg.random_change_ratio > 0:
            aas = random_change_augmentation(aas, self.cfg)
        if (
            random.random() <= self.cfg.random_delete_prob
        ) and self.cfg.random_delete_ratio > 0:
            aas = random_delete_augmentation(aas, self.cfg)
        if (random.random() <= self.cfg.mask_prob) and self.cfg.mask_ratio > 0:
            aas = mask_augmentation(aas, self.cfg)
        if random.random() <= self.cfg.truncate_augmentation_prob:
            aas = truncate_augmentation(aas, self.cfg)
        return aas


class LSTMDataset(Dataset):
    """Dataset that one-hot encodes protein sequences for an LSTM model."""

    def __init__(self, cfg, df, train=True):
        self.df = df
        self.cfg = cfg
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        data = self.df.iloc[idx]
        aas = data[self.cfg.sequence_col]
        aas = self._adjust_sequence_length(aas)

        # Remove the "__" separator inserted by the "both" truncation strategy.
        aas = aas.replace("__", "")
        inputs = one_hot_encode_input(aas, self.cfg)
        return inputs, torch.tensor(data["target"], dtype=torch.float32)

    def _adjust_sequence_length(self, aas):
        max_length = (self.cfg.max_length - 2) * self.cfg.token_length
        if len(aas) > max_length:
            if self.cfg.used_sequence == "left":
                return aas[:max_length]
            elif self.cfg.used_sequence == "right":
                return aas[-max_length:]
            elif self.cfg.used_sequence == "both":
                half_max_len = max_length // 2
                return aas[:half_max_len] + "__" + aas[-half_max_len:]
            elif self.cfg.used_sequence == "internal":
                offset = (len(aas) - max_length) // 2
                return aas[offset:offset + max_length]
        return aas
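

if __name__ == "__main__":
    # Minimal, illustrative usage sketch, not part of the training pipeline.
    # The config values and the toy DataFrame below are hypothetical; they only
    # show the fields each dataset expects (sequence_col, max_length,
    # token_length, used_sequence, and a "target" column in the DataFrame).
    from types import SimpleNamespace

    import pandas as pd

    cfg = SimpleNamespace(
        sequence_col="sequence",  # column holding the amino-acid string
        max_length=10,            # number of one-hot rows / tokens per example
        token_length=1,           # residues covered by one token
        used_sequence="left",     # truncation strategy: left/right/both/internal
    )
    df = pd.DataFrame(
        {"sequence": ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"], "target": [1.0]}
    )

    dataset = LSTMDataset(cfg, df, train=False)
    one_hot, target = dataset[0]
    print(one_hot.shape, target)  # -> torch.Size([10, 21]) tensor(1.)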