File size: 1,964 Bytes
4321e7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import random
def random_change_augmentation(aas, cfg):
    """Randomly substitute a fraction of the characters of ``aas``.

    ``int(len(aas) * cfg.random_change_ratio)`` distinct positions are drawn
    without replacement. At each drawn position an upper-case letter from the
    20-letter residue alphabet is replaced by a random letter of that
    alphabet, and a lower-case letter from the matching lower-case alphabet is
    replaced by a random lower-case one (the replacement may equal the
    original letter). Any other character is kept as-is, so the result always
    has the same length as the input.

    Args:
        aas: Sequence string to augment.
        cfg: Object exposing ``random_change_ratio``, the fraction of
            positions to resample.

    Returns:
        A new string of the same length as ``aas``.

    Raises:
        ValueError: Propagated from ``random.sample`` when the ratio yields a
            sample size that is negative or larger than ``len(aas)``.
    """
    residue_tokens = (
        "A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
        "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y",
    )
    structure_aware_tokens = (
        "a", "c", "d", "e", "f", "g", "h", "i", "k", "l",
        "m", "n", "p", "q", "r", "s", "t", "v", "w", "y",
    )
    length = len(aas)
    # A set gives O(1) membership tests inside the loop below.
    swap_indices = set(
        random.sample(range(length), int(length * cfg.random_change_ratio))
    )
    pieces = []
    for i, aa in enumerate(aas):
        if i not in swap_indices:
            pieces.append(aa)
        elif aa in residue_tokens:
            pieces.append(random.choice(residue_tokens))
        elif aa in structure_aware_tokens:
            pieces.append(random.choice(structure_aware_tokens))
        else:
            # Characters outside both alphabets used to be dropped here,
            # which silently shortened the sequence; keep them instead.
            pieces.append(aa)
    return "".join(pieces)
def mask_augmentation(aas, cfg):
    """Replace a random subset of fixed-width tokens with ``"<mask>"``.

    ``aas`` is read as ``len(aas) // cfg.token_length`` consecutive tokens of
    ``cfg.token_length`` characters each. ``int(num_tokens * cfg.mask_ratio)``
    distinct tokens are drawn without replacement and each one is replaced by
    exactly one ``"<mask>"``, whatever the token width and whether or not
    neighbouring tokens are masked too. Trailing characters that do not fill
    a whole token are copied through unchanged.

    Args:
        aas: Sequence string to augment.
        cfg: Object exposing ``token_length`` (characters per token) and
            ``mask_ratio`` (fraction of tokens to mask).

    Returns:
        The sequence with the selected tokens replaced by ``"<mask>"``.

    Raises:
        ValueError: Propagated from ``random.sample`` when the ratio yields a
            sample size that is negative or larger than the token count.
        ZeroDivisionError: If ``cfg.token_length`` is 0.
    """
    token_length = cfg.token_length
    num_tokens = len(aas) // token_length
    masked = set(
        random.sample(range(num_tokens), int(num_tokens * cfg.mask_ratio))
    )
    # Build the output token by token rather than via a sentinel character:
    # the sentinel approach merged adjacent masks when token_length == 1 and
    # produced two masks per token when token_length == 3.
    pieces = [
        "<mask>"
        if idx in masked
        else aas[idx * token_length : (idx + 1) * token_length]
        for idx in range(num_tokens)
    ]
    # Keep any incomplete trailing token untouched.
    pieces.append(aas[num_tokens * token_length :])
    return "".join(pieces)
def random_delete_augmentation(aas, cfg):
    """Drop a random subset of fixed-width tokens from ``aas``.

    The string is read as ``len(aas) // cfg.token_length`` consecutive tokens
    of ``cfg.token_length`` characters. ``int(token_count *
    cfg.random_delete_ratio)`` distinct tokens are drawn without replacement
    and removed. Trailing characters that do not fill a whole token are kept.
    Every ``"@"`` character in the result is stripped as well.

    Args:
        aas: Sequence string to augment.
        cfg: Object exposing ``token_length`` and ``random_delete_ratio``.

    Returns:
        The sequence with the selected tokens (and any ``"@"``) removed.
    """
    step = cfg.token_length
    token_count = len(aas) // step
    doomed = set(
        random.sample(
            range(0, token_count),
            int(token_count * cfg.random_delete_ratio),
        )
    )
    survivors = [
        aas[k * step : (k + 1) * step]
        for k in range(token_count)
        if k not in doomed
    ]
    # Characters past the last complete token are never candidates.
    survivors.append(aas[token_count * step :])
    return "".join(survivors).replace("@", "")
def truncate_augmentation(aas, cfg):
    """Crop ``aas`` to at most ``cfg.max_length`` characters.

    A sequence that already fits is returned unchanged. A longer one is cut
    down to a contiguous window of ``cfg.max_length`` characters whose start
    offset is drawn uniformly from every position that keeps the window
    inside the sequence.

    Args:
        aas: Sequence string to crop.
        cfg: Object exposing ``max_length``.

    Returns:
        ``aas`` itself, or a random ``cfg.max_length``-character slice of it.
    """
    limit = cfg.max_length
    overflow = len(aas) - limit
    if overflow <= 0:
        return aas
    # randint is inclusive at both ends, so the window may end on the last char.
    offset = random.randint(0, overflow)
    return aas[offset : offset + limit]