|
import regex as re
|
|
import torch
|
|
import numpy as np
|
|
import random
|
|
import collections
|
|
|
|
class Encoder():
    """Tokenize molecule strings, bucket them by length and build MLM batches.

    Each input record is split into tokens with ``self.regex`` (the pattern
    looks like a SMILES tokenizer -- confirm against the data source), routed
    into one of four length buckets, mapped to integer ids, padded, and
    finally masked for masked-language-model training.

    Buckets have a fixed capacity (``b0_max`` .. ``b3_max``).  Sequences that
    arrive while their bucket is full are parked in a matching cache deque
    and used to top up that bucket on a later call to ``process_text``.
    """

    def __init__(self, max_length=500, add_bos=True, add_eos=True, feature_size=32,
                 vocab_path='pubchem_canon_zinc_final_vocab_sorted_curated.pth'):
        """Load the vocabulary and set up buckets, id tables and the tokenizer.

        Args:
            max_length: exclusive upper bound on the token count of a
                sequence; longer sequences are dropped by ``process_text``.
            add_bos: prepend ``<bos>`` to every encoded sequence.
            add_eos: append ``<eos>`` to every encoded sequence.
            feature_size: selects the bucket capacities; ``32`` picks
                (1100, 700, 150, 50), anything else (1382, 871, 516, 311).
            vocab_path: file passed to ``torch.load``; it must hold a mapping
                from token string to a number (treated here as a frequency
                count).  Defaults to the path that used to be hard-coded.
        """
        # NOTE(review): torch.load unpickles the file -- only load vocabulary
        # files from a trusted source.
        self.vocab_encoder = torch.load(vocab_path)

        # Length thresholds (in tokens) separating the four buckets.
        self.max_length = max_length
        self.min_length = 1
        self.mod_length = 42
        self.mlm_probability = .15
        self.avg_length = 66
        self.tail = 122

        # Overflow caches and the active buckets, one pair per length range.
        self.b0_cache = collections.deque()
        self.b1_cache = collections.deque()
        self.b2_cache = collections.deque()
        self.b3_cache = collections.deque()
        self.bucket0 = collections.deque()
        self.bucket1 = collections.deque()
        self.bucket2 = collections.deque()
        self.bucket3 = collections.deque()

        if feature_size == 32:
            self.b0_max, self.b1_max, self.b2_max, self.b3_max = 1100, 700, 150, 50
        else:
            self.b0_max, self.b1_max, self.b2_max, self.b3_max = 1382, 871, 516, 311

        # cutoffs[0]: tokens with count > 100000, plus the 4 special tokens.
        # cutoffs[1]: tokens with 50 < count <= 100000.
        counts = list(self.vocab_encoder.values())
        num_top = sum(1 for count in counts if count > 100000)
        middle_top = sum(1 for count in counts if count > 50) - num_top
        self.cutoffs = [num_top + 4, middle_top]

        self.char2id = {"<bos>": 0, "<eos>": 1, "<pad>": 2, "<mask>": 3}
        self.id2char = {0: "<bos>", 1: "<eos>", 2: "<pad>", 3: "<mask>"}
        self.pad = self.char2id['<pad>']
        self.mask = self.char2id['<mask>']
        self.eos = self.char2id['<eos>']
        self.bos = self.char2id['<bos>']

        # Vocabulary tokens take ids 4.. in the mapping's iteration order;
        # "<unk>" gets the id right after the last vocabulary token.
        for offset, key in enumerate(self.vocab_encoder):
            self.char2id[key] = offset + 4
            self.id2char[offset + 4] = key
        unk_id = len(self.vocab_encoder) + 4
        self.char2id["<unk>"] = unk_id
        self.id2char[unk_id] = "<unk>"
        self.unk = unk_id

        # Raw string: same pattern text as before, without relying on
        # Python passing unknown escapes such as "\[" through unchanged.
        self.pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.pattern)
        self.add_bos = add_bos
        self.add_eos = add_eos

    def encode(self, char):
        """Map one token sequence to a 1-D tensor of vocabulary ids.

        ``<bos>`` / ``<eos>`` are added according to ``add_bos`` / ``add_eos``.
        Tokens missing from ``char2id`` are encoded as ``<unk>`` instead of
        raising ``KeyError``.  The input sequence is not modified.
        """
        tokens = list(char)
        if self.add_bos:
            tokens.insert(0, '<bos>')
        if self.add_eos:
            tokens.append('<eos>')
        unk = self.unk
        return torch.tensor([self.char2id.get(word, unk) for word in tokens])

    def encoder(self, tokens):
        """Encode every token sequence in ``tokens``; returns a list of tensors."""
        return [self.encode(mol) for mol in tokens]

    @staticmethod
    def _stash(item, bucket, cache, capacity):
        """Append ``item`` to ``bucket``, or to ``cache`` once the bucket is full."""
        if len(bucket) < capacity:
            bucket.append(item)
        else:
            cache.append(item)

    @staticmethod
    def _drain(bucket, cache, capacity):
        """Empty ``bucket`` (newest first) and top it up from ``cache``.

        The cache is only consulted when the bucket holds fewer than
        ``capacity`` items; at most the missing number of items is taken,
        newest cached item first.
        """
        top_up = 0
        if len(bucket) < capacity and cache:
            top_up = min(len(cache), capacity - len(bucket))
        drained = [bucket.pop() for _ in range(len(bucket))]
        drained.extend(cache.pop() for _ in range(top_up))
        return drained

    def process_text(self, text):
        """Tokenize records and return the four length buckets.

        Args:
            text: iterable of mappings, each with the string under ``'text'``.

        Returns:
            A 4-tuple of lists of token lists, for token counts in
            ``(min_length, mod_length)``, ``[mod_length, avg_length)``,
            ``[avg_length, tail)`` and ``[tail, max_length)``.  Sequences
            outside those ranges are dropped.  All buckets are left empty;
            overflow beyond a bucket's capacity stays in its cache.
        """
        mod_length = self.mod_length
        avg_length = self.avg_length
        for mol in text:
            raw_text = mol['text']
            if '\n' in raw_text:
                print('carriage return in mol')
            raw_regex = self.regex.findall(raw_text.strip('\n'))
            length = len(raw_regex)
            if self.min_length < length < mod_length:
                self._stash(raw_regex, self.bucket0, self.b0_cache, self.b0_max)
            elif mod_length <= length < avg_length:
                self._stash(raw_regex, self.bucket1, self.b1_cache, self.b1_max)
            elif avg_length <= length < self.tail:
                self._stash(raw_regex, self.bucket2, self.b2_cache, self.b2_max)
            elif self.tail <= length < self.max_length:
                self._stash(raw_regex, self.bucket3, self.b3_cache, self.b3_max)

        # Pour cached sequences into any bucket that still has room.
        outbucket0 = self._drain(self.bucket0, self.b0_cache, self.b0_max)
        outbucket1 = self._drain(self.bucket1, self.b1_cache, self.b1_max)
        outbucket2 = self._drain(self.bucket2, self.b2_cache, self.b2_max)
        outbucket3 = self._drain(self.bucket3, self.b3_cache, self.b3_max)
        return outbucket0, outbucket1, outbucket2, outbucket3

    def _special_tokens_mask(self, ids):
        """Boolean tensor marking ``<bos>``, ``<eos>`` and ``<pad>`` positions."""
        return (ids == self.bos) | (ids == self.eos) | (ids == self.pad)

    def mask_tokens(self, inputs, special_tokens_mask=None):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: integer id tensor; it is modified in place.
            special_tokens_mask: optional mask (nested list or tensor, truthy
                where a position must never be selected).  When omitted it is
                derived from the ``<bos>``/``<eos>``/``<pad>`` ids.

        Returns:
            ``(inputs, labels)`` where ``labels`` holds the original id at
            every selected position and ``-100`` everywhere else.
        """
        labels = inputs.clone()

        # Each non-special position is selected with probability mlm_probability.
        probability_matrix = torch.full(labels.size(), self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = self._special_tokens_mask(labels)
        else:
            special_tokens_mask = torch.as_tensor(special_tokens_mask, dtype=torch.bool)

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100

        # 80% of the selected positions become <mask>.
        indices_replaced = torch.bernoulli(torch.full(labels.size(), 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.mask

        # Half of the remaining selected positions (10% overall) get a random id.
        indices_random = torch.bernoulli(torch.full(labels.size(), 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.char2id), labels.size(), dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the selected positions keep their original id.
        return inputs, labels

    def pack_tensors(self, tokens):
        """Encode, pad and mask one bucket of token sequences.

        Returns:
            ``(masked_array, masked_labels, array_ids, lengths)``: the padded
            and masked id matrix, its MLM labels, the list of unpadded,
            unmasked id tensors, and the per-row count of non-pad ids.
        """
        array_ids = self.encoder(tokens)
        array = torch.nn.utils.rnn.pad_sequence(array_ids, batch_first=True, padding_value=self.pad)
        # Computed before masking, which rewrites ``array`` in place.
        lengths = (array != self.pad).sum(dim=-1)
        special_token_mask = self._special_tokens_mask(array)
        masked_array, masked_labels = self.mask_tokens(array, special_token_mask)
        return masked_array, masked_labels, array_ids, lengths

    def process(self, text):
        """Run the full pipeline on ``text`` and collect one entry per non-empty bucket.

        Returns:
            Four parallel lists ``(arrays, targets, arrays_ids, lengths)``
            holding the ``pack_tensors`` outputs of each non-empty bucket.
        """
        arrays = []
        lengths = []
        targets = []
        arrays_ids = []
        for tokens in self.process_text(text):
            if len(tokens) > 0:
                array, target, array_ids, lgt = self.pack_tensors(tokens)
                arrays.append(array)
                targets.append(target)
                arrays_ids.append(array_ids)
                lengths.append(lgt)
        return arrays, targets, arrays_ids, lengths
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build an Encoder with its default arguments.
    text_encoder = Encoder()
|
|
|