try:
    import regex as re
except ImportError:
    # The tokenizer pattern below only uses syntax the stdlib engine also
    # supports, so fall back to it when the third-party `regex` is missing.
    import re
import torch
import numpy as np
import random
import collections


class Encoder():
    """Tokenizer that buckets molecules by token count and builds
    masked-language-model (MLM) batches.

    Id layout: the four special tokens take ids 0-3, the keys of the loaded
    vocabulary follow in iteration order starting at id 4, and ``<unk>`` takes
    the id right after the last vocabulary key.
    """

    def __init__(self, max_length=500, add_bos=True, add_eos=True, feature_size=32,
                 vocab_path='pubchem_canon_zinc_final_vocab_sorted_curated.pth'):
        """
        Args:
            max_length: sequences with ``max_length`` tokens or more are dropped
                by ``process_text``.
            add_bos: if true, ``encode`` prepends ``<bos>``.
            add_eos: if true, ``encode`` appends ``<eos>``.
            feature_size: 32 selects the smaller per-bucket capacities; any
                other value selects the larger set.
            vocab_path: file read with ``torch.load``; its values are treated
                as occurrence counts (see the cutoff computation below).
                NOTE(review): torch.load may unpickle arbitrary objects
                depending on the torch version; only load trusted files.
        """
        self.vocab_encoder = torch.load(vocab_path)
        self.max_length = max_length
        self.min_length = 1
        self.mod_length = 42   # exclusive upper bound of bucket 0
        self.mlm_probability = .15
        self.avg_length = 66   # exclusive upper bound of bucket 1
        self.tail = 122        # exclusive upper bound of bucket 2

        # Buckets are emptied on every process_text call; items arriving while
        # a bucket is full go to the matching cache, which is used to top up
        # under-filled buckets on later calls.
        self.b0_cache = collections.deque()
        self.b1_cache = collections.deque()
        self.b2_cache = collections.deque()
        self.b3_cache = collections.deque()
        self.bucket0 = collections.deque()
        self.bucket1 = collections.deque()
        self.bucket2 = collections.deque()
        self.bucket3 = collections.deque()
        if feature_size == 32:
            self.b0_max = 1100
            self.b1_max = 700
            self.b2_max = 150
            self.b3_max = 50
        else:
            self.b0_max = 1382
            self.b1_max = 871
            self.b2_max = 516
            self.b3_max = 311

        counts = list(self.vocab_encoder.values())
        num_top = sum(1 for count in counts if count > 100000)
        middle_top = sum(1 for count in counts if count > 50) - num_top
        # The +4 presumably offsets for the special-token ids that precede the
        # vocabulary ids.
        self.cutoffs = [num_top + 4, middle_top]

        # NOTE(review): in the reviewed copy these literals were empty strings
        # (angle-bracketed text apparently stripped), which made every special
        # token share a single id. The names and the 0-3 order used here must
        # match the vocabulary of any pretrained checkpoint - confirm.
        specials = ['<bos>', '<eos>', '<pad>', '<mask>']
        self.char2id = {tok: idx for idx, tok in enumerate(specials)}
        self.id2char = {idx: tok for idx, tok in enumerate(specials)}
        self.pad = self.char2id['<pad>']
        self.mask = self.char2id['<mask>']
        self.eos = self.char2id['<eos>']
        self.bos = self.char2id['<bos>']

        offset = len(specials)
        for pos, key in enumerate(self.vocab_encoder):
            self.char2id[key] = pos + offset
            self.id2char[pos + offset] = key
        unk_id = len(self.vocab_encoder) + offset
        self.char2id['<unk>'] = unk_id
        self.id2char[unk_id] = '<unk>'
        self.unk = unk_id

        # Raw string: same compiled value as before, without invalid-escape warnings.
        self.pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.pattern)
        self.add_bos = add_bos
        self.add_eos = add_eos

    def encode(self, char):
        """Convert a list of tokens to a 1-D ``torch.long`` tensor of ids.

        Adds ``<bos>``/``<eos>`` according to ``add_bos``/``add_eos``. Tokens
        missing from the vocabulary map to the ``<unk>`` id. No truncation.
        """
        if self.add_bos:
            char = ['<bos>'] + char
        if self.add_eos:
            char = char + ['<eos>']
        unk = self.char2id['<unk>']
        # Explicit dtype so an empty token list still yields integer ids.
        return torch.tensor([self.char2id.get(word, unk) for word in char], dtype=torch.long)

    def encoder(self, tokens):
        """Encode every token list in ``tokens``; returns a list of 1-D id tensors."""
        return [self.encode(mol) for mol in tokens]

    def _bucket_index(self, length, mod_length, avg_length):
        """Return the bucket index (0-3) for a sequence of ``length`` tokens, or None to drop it."""
        if self.min_length < length < mod_length:
            return 0
        if mod_length <= length < avg_length:
            return 1
        if avg_length <= length < self.tail:
            return 2
        if self.tail <= length < self.max_length:
            return 3
        return None

    @staticmethod
    def _route(item, bucket, cache, capacity):
        """Append ``item`` to ``bucket`` while it has room, otherwise to its overflow ``cache``."""
        if len(bucket) < capacity:
            bucket.append(item)
        else:
            cache.append(item)

    @staticmethod
    def _drain(bucket, cache, capacity):
        """Empty ``bucket`` (most recent first), then top up from ``cache`` to at most ``capacity`` items."""
        # Free room is measured before the bucket is emptied.
        room = capacity - len(bucket)
        take = max(0, min(len(cache), room))
        out = [bucket.pop() for _ in range(len(bucket))]
        out.extend(cache.pop() for _ in range(take))
        return out

    def process_text(self, text):
        """Tokenize molecules, sort them into length buckets and drain the buckets.

        Args:
            text: iterable of dicts whose ``'text'`` key holds a molecule string
                (presumably SMILES, given the tokenizer pattern).

        Returns:
            A 4-tuple of lists of token lists, one per length bucket (shortest
            first). Each bucket is emptied by this call, and any remaining room
            is filled from that bucket's overflow cache.
        """
        # random length sequences seems to help training
        mod_length = self.mod_length  # + random.randint(-1, 3)
        avg_length = self.avg_length  # + random.randint(-3, 5)
        buckets = (self.bucket0, self.bucket1, self.bucket2, self.bucket3)
        caches = (self.b0_cache, self.b1_cache, self.b2_cache, self.b3_cache)
        capacities = (self.b0_max, self.b1_max, self.b2_max, self.b3_max)

        for mol in text:
            if '\n' in mol['text']:
                print('carriage return in mol')
            raw_regex = self.regex.findall(mol['text'].strip('\n'))
            idx = self._bucket_index(len(raw_regex), mod_length, avg_length)
            if idx is not None:
                self._route(raw_regex, buckets[idx], caches[idx], capacities[idx])

        return tuple(self._drain(bucket, cache, cap)
                     for bucket, cache, cap in zip(buckets, caches, capacities))

    def mask_tokens(self, inputs, special_tokens_mask=None):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: integer id tensor; it is modified IN PLACE and also returned.
            special_tokens_mask: bool-like mask (nested list or tensor) of
                positions never selected for masking. If None, positions holding
                the bos/eos/pad ids are used.

        Returns:
            ``(inputs, labels)``; ``labels`` keeps the original id at selected
            positions and is -100 elsewhere, so loss is computed only there.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.size(), self.mlm_probability)
        if special_tokens_mask is None:
            # The previous fallback referenced a non-existent `self.tokenizer`.
            special_tokens_mask = (inputs == self.bos) | (inputs == self.eos) | (inputs == self.pad)
        else:
            special_tokens_mask = torch.as_tensor(special_tokens_mask, dtype=torch.bool)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with the mask token
        indices_replaced = torch.bernoulli(torch.full(labels.size(), 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.mask

        # 10% of the time (half of the remaining 20%), we replace masked input tokens with a random id
        indices_random = torch.bernoulli(torch.full(labels.size(), 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.char2id), labels.size(), dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def pack_tensors(self, tokens):
        """Encode, pad and MLM-mask a list of token lists.

        Returns:
            ``(masked_array, masked_labels, array_ids, lengths)``. ``array_ids``
            is the unpadded per-sequence id tensors. ``lengths`` counts the
            non-pad ids per row, computed before masking.
        """
        array_ids = self.encoder(tokens)
        array = torch.nn.utils.rnn.pad_sequence(array_ids, batch_first=True, padding_value=self.pad)
        lengths = (array != self.pad).sum(dim=-1)
        # Bert tokenization: bos/eos/pad are never selected for masking.
        special_token_mask = (array == self.bos) | (array == self.eos) | (array == self.pad)
        masked_array, masked_labels = self.mask_tokens(array, special_token_mask)
        return masked_array, masked_labels, array_ids, lengths

    def process(self, text):
        """Run ``process_text`` and pack every non-empty bucket into MLM tensors.

        Returns:
            Four parallel lists (arrays, targets, array ids, lengths), with one
            entry per non-empty bucket in bucket order.
        """
        arrays, targets, arrays_ids, lengths = [], [], [], []
        for tokens in self.process_text(text):
            if not tokens:
                continue
            array, target, array_ids, lgt = self.pack_tensors(tokens)
            arrays.append(array)
            targets.append(target)
            arrays_ids.append(array_ids)
            lengths.append(lgt)
        return arrays, targets, arrays_ids, lengths


if __name__ == '__main__':
    text_encoder = Encoder()