Spaces:
Runtime error
Topallaj Denis committed
Commit • c7272f2
1 Parent(s): 959f4cc
copied the unikp model into this endpoint
- Kcat.pkl +3 -0
- Kcat_over_Km.pkl +3 -0
- Km.pkl +3 -0
- build_vocab.py +148 -0
- dataset.py +56 -0
- enumerator.py +223 -0
- main.py +190 -4
- pretrain_trfm.py +175 -0
- trfm_12_23000.pkl +3 -0
- utils.py +194 -0
- vocab.pkl +3 -0
- vocab_content.txt +45 -0
Kcat.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fe90811273401698a2c25ab32959f13c0087a14feb8ca310cf4b44dcad819fd5
size 205501172
Kcat_over_Km.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f81f8715a87790c542023b1bab1da6055a60b9db22c20c04a8846d9b09ba844
size 11476980
Km.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f0b8d597cf4e5f73431980950cb89423415a260dbbf9be0bb1d8810712bf9c07
size 147957236
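
The three .pkl files above are stored as Git LFS pointers, so only the version/oid/size metadata lives in the repo and the actual model weights are resolved from LFS storage. A minimal sketch of fetching one of them programmatically with huggingface_hub; the repo id below is hypothetical and the real Space id may differ.

# Sketch (assumptions: hypothetical repo id, huggingface_hub installed).
from huggingface_hub import hf_hub_download

km_path = hf_hub_download(
    repo_id="user/unikp-endpoint",   # hypothetical Space id, replace with the real one
    filename="Km.pkl",
    repo_type="space",
)
print(km_path)  # local cache path of the ~148 MB Km regressor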
build_vocab.py
ADDED
@@ -0,0 +1,148 @@
import pickle
from collections import Counter

class TorchVocab(object):
    """
    :property freqs: collections.Counter, holds the frequency of each word in the corpus
    :property stoi: collections.defaultdict, mapping from string to id
    :property itos: collections.defaultdict, mapping from id to string
    """
    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
                 vectors=None, unk_init=None, vectors_cache=None):
        """
        :param counter: collections.Counter, counter measuring word frequencies in the data
        :param max_size: int, maximum size of the vocabulary. None means no limit. Default is None.
        :param min_freq: int, minimum frequency for a word in the vocabulary. Words that appear less often are not added.
        :param specials: list of str, tokens registered in the vocabulary in advance
        :param vectors: list of vectors, pretrained vectors, e.g. Vocab.load_vectors
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # frequencies of special tokens are not counted when building the vocabulary
        for tok in specials:
            del counter[tok]

        max_size = None if max_size is None else max_size + len(self.itos)

        # sort by frequency first, then alphabetically
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        # words with frequency below min_freq are not added to the vocab
        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        # build stoi by swapping the keys and values of the dict
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        return len(self.itos)

    def vocab_rerank(self):
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1


class Vocab(TorchVocab):
    def __init__(self, counter, max_size=None, min_freq=1):
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"], max_size=max_size, min_freq=min_freq)

    # meant to be overridden
    def to_seq(self, sentence, seq_len, with_eos=False, with_sos=False) -> list:
        pass

    # meant to be overridden
    def from_seq(self, seq, join=False, with_pad=False):
        pass

    def load_vocab(vocab_path: str) -> 'Vocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)


# build a vocab from a text file
class WordVocab(Vocab):
    def __init__(self, texts, max_size=None, min_freq=1):
        print("Building Vocab")
        counter = Counter()
        for line in texts:
            if isinstance(line, list):
                words = line
            else:
                words = line.replace("\n", "").replace("\t", "").split()

            for word in words:
                counter[word] += 1
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        if isinstance(sentence, str):
            sentence = sentence.split()

        seq = [self.stoi.get(word, self.unk_index) for word in sentence]

        if with_eos:
            seq += [self.eos_index]  # <eos>, index 2
        if with_sos:
            seq = [self.sos_index] + seq

        origin_seq_len = len(seq)

        if seq_len is None:
            pass
        elif len(seq) <= seq_len:
            seq += [self.pad_index for _ in range(seq_len - len(seq))]
        else:
            seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        words = [self.itos[idx]
                 if idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if not with_pad or idx != self.pad_index]

        return " ".join(words) if join else words

    def load_vocab(vocab_path: str) -> 'WordVocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)
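
A minimal usage sketch of WordVocab as this endpoint uses it (see main.py below): build the vocabulary from the plain-text token list in vocab_content.txt instead of the pickled vocab.pkl, then round-trip a space-separated SMILES. Paths and the example string are illustrative.

# Sketch: build the vocab from vocab_content.txt and round-trip a token sequence.
from build_vocab import WordVocab

with open("vocab_content.txt", "r", encoding="utf-8") as f:
    vocab = WordVocab(f.read().strip().split("\n"))   # one token per line

ids = vocab.to_seq("C C O", with_sos=True, with_eos=True)
print(ids)                              # e.g. [3, ..., ..., ..., 2] -> <sos> ... <eos>
print(vocab.from_seq(ids, join=True))   # "<sos> C C O <eos>"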
dataset.py
ADDED
@@ -0,0 +1,56 @@
import random
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from enumerator import SmilesEnumerator
from utils import split

PAD = 0
MAX_LEN = 220

class Randomizer(object):

    def __init__(self):
        self.sme = SmilesEnumerator()

    def __call__(self, sm):
        sm_r = self.sme.randomize_smiles(sm)  # random transform
        if sm_r is None:
            sm_spaced = split(sm)  # spacing
        else:
            sm_spaced = split(sm_r)  # spacing
        sm_split = sm_spaced.split()
        if len(sm_split) <= MAX_LEN - 2:
            return sm_split  # list
        else:
            return split(sm).split()

    def random_transform(self, sm):
        '''
        function: Random transformation for SMILES. It may take some time.
        input: A SMILES
        output: A randomized SMILES
        '''
        return self.sme.randomize_smiles(sm)

class Seq2seqDataset(Dataset):

    def __init__(self, smiles, vocab, seq_len=220, transform=Randomizer()):
        self.smiles = smiles
        self.vocab = vocab
        self.seq_len = seq_len
        self.transform = transform

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, item):
        sm = self.smiles[item]
        sm = self.transform(sm)  # list
        content = [self.vocab.stoi.get(token, self.vocab.unk_index) for token in sm]
        X = [self.vocab.sos_index] + content + [self.vocab.eos_index]
        padding = [self.vocab.pad_index]*(self.seq_len - len(X))
        X.extend(padding)
        return torch.tensor(X)
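
A short sketch of how Seq2seqDataset is consumed during pretraining (see pretrain_trfm.py further down): wrap a list of SMILES and a WordVocab in a DataLoader and get back padded index tensors of shape (batch, 220). The molecules are illustrative, and RDKit must be installed because the default Randomizer transform calls SmilesEnumerator.

# Sketch: illustrative SMILES through Seq2seqDataset (requires rdkit).
from torch.utils.data import DataLoader
from build_vocab import WordVocab
from dataset import Seq2seqDataset

with open("vocab_content.txt", encoding="utf-8") as f:
    vocab = WordVocab(f.read().strip().split("\n"))

smiles = ["CCO", "c1ccccc1", "CC(=O)Oc1ccccc1C(=O)O"]      # illustrative molecules
loader = DataLoader(Seq2seqDataset(smiles, vocab), batch_size=2, shuffle=True)

batch = next(iter(loader))
print(batch.shape)   # torch.Size([2, 220]): <sos> + tokens + <eos>, padded with <pad>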
enumerator.py
ADDED
@@ -0,0 +1,223 @@
# Experimental Class for Smiles Enumeration, Iterator and SmilesIterator adapted from Keras 1.2.2
from rdkit import Chem
import numpy as np
import threading

class Iterator(object):
    """Abstract base class for data iterators.

    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """

    def __init__(self, n, batch_size, shuffle, seed):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batch_index = 0
        self.total_batches_seen = 0
        self.lock = threading.Lock()
        self.index_generator = self._flow_index(n, batch_size, shuffle, seed)
        if n < batch_size:
            raise ValueError('Input data length is shorter than batch_size\nAdjust batch_size')

    def reset(self):
        self.batch_index = 0

    def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
        # Ensure self.batch_index is 0.
        self.reset()
        while 1:
            if seed is not None:
                np.random.seed(seed + self.total_batches_seen)
            if self.batch_index == 0:
                index_array = np.arange(n)
                if shuffle:
                    index_array = np.random.permutation(n)

            current_index = (self.batch_index * batch_size) % n
            if n > current_index + batch_size:
                current_batch_size = batch_size
                self.batch_index += 1
            else:
                current_batch_size = n - current_index
                self.batch_index = 0
            self.total_batches_seen += 1
            yield (index_array[current_index: current_index + current_batch_size],
                   current_index, current_batch_size)

    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        return self

    def __next__(self, *args, **kwargs):
        return self.next(*args, **kwargs)


class SmilesIterator(Iterator):
    """Iterator yielding data from a SMILES array.

    # Arguments
        x: Numpy array of SMILES input data.
        y: Numpy array of targets data.
        smiles_data_generator: Instance of `SmilesEnumerator`
            to use for random SMILES generation.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        dtype: dtype to use for returned batch. Set to keras.backend.floatx if using Keras
    """

    def __init__(self, x, y, smiles_data_generator,
                 batch_size=32, shuffle=False, seed=None,
                 dtype=np.float32
                 ):
        if y is not None and len(x) != len(y):
            raise ValueError('X (images tensor) and y (labels) '
                             'should have the same length. '
                             'Found: X.shape = %s, y.shape = %s' %
                             (np.asarray(x).shape, np.asarray(y).shape))

        self.x = np.asarray(x)

        if y is not None:
            self.y = np.asarray(y)
        else:
            self.y = None
        self.smiles_data_generator = smiles_data_generator
        self.dtype = dtype
        super(SmilesIterator, self).__init__(x.shape[0], batch_size, shuffle, seed)

    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch.
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        batch_x = np.zeros(tuple([current_batch_size] + [self.smiles_data_generator.pad, self.smiles_data_generator._charlen]), dtype=self.dtype)
        for i, j in enumerate(index_array):
            smiles = self.x[j:j+1]
            x = self.smiles_data_generator.transform(smiles)
            batch_x[i] = x

        if self.y is None:
            return batch_x
        batch_y = self.y[index_array]
        return batch_x, batch_y


class SmilesEnumerator(object):
    """SMILES Enumerator, vectorizer and devectorizer

    # Arguments
        charset: string containing the characters for the vectorization
            can also be generated via the .fit() method
        pad: Length of the vectorization
        leftpad: Add spaces to the left of the SMILES
        isomericSmiles: Generate SMILES containing information about stereogenic centers
        enum: Enumerate the SMILES during transform
        canonical: use canonical SMILES during transform (overrides enum)
    """
    def __init__(self, charset='@C)(=cOn1S2/H[N]\\', pad=120, leftpad=True, isomericSmiles=True, enum=True, canonical=False):
        self._charset = None
        self.charset = charset
        self.pad = pad
        self.leftpad = leftpad
        self.isomericSmiles = isomericSmiles
        self.enumerate = enum
        self.canonical = canonical

    @property
    def charset(self):
        return self._charset

    @charset.setter
    def charset(self, charset):
        self._charset = charset
        self._charlen = len(charset)
        self._char_to_int = dict((c, i) for i, c in enumerate(charset))
        self._int_to_char = dict((i, c) for i, c in enumerate(charset))

    def fit(self, smiles, extra_chars=[], extra_pad=5):
        """Performs extraction of the charset and length of a SMILES dataset and sets self.pad and self.charset

        # Arguments
            smiles: Numpy array or Pandas series containing smiles as strings
            extra_chars: List of extra chars to add to the charset (e.g. "\\\\" when "/" is present)
            extra_pad: Extra padding to add before or after the SMILES vectorization
        """
        charset = set("".join(list(smiles)))
        self.charset = "".join(charset.union(set(extra_chars)))
        self.pad = max([len(smile) for smile in smiles]) + extra_pad

    def randomize_smiles(self, smiles):
        """Perform a randomization of a SMILES string
        must be RDKit sanitizable"""
        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return None  # Invalid SMILES
        ans = list(range(m.GetNumAtoms()))
        np.random.shuffle(ans)
        nm = Chem.RenumberAtoms(m, ans)
        return Chem.MolToSmiles(nm, canonical=self.canonical, isomericSmiles=self.isomericSmiles)

    def transform(self, smiles):
        """Perform an enumeration (randomization) and vectorization of a Numpy array of smiles strings
        # Arguments
            smiles: Numpy array or Pandas series containing smiles as strings
        """
        one_hot = np.zeros((smiles.shape[0], self.pad, self._charlen), dtype=np.int8)

        if self.leftpad:
            for i, ss in enumerate(smiles):
                if self.enumerate:
                    ss = self.randomize_smiles(ss)
                l = len(ss)
                diff = self.pad - l
                for j, c in enumerate(ss):
                    one_hot[i, j + diff, self._char_to_int[c]] = 1
            return one_hot
        else:
            for i, ss in enumerate(smiles):
                if self.enumerate:
                    ss = self.randomize_smiles(ss)
                for j, c in enumerate(ss):
                    one_hot[i, j, self._char_to_int[c]] = 1
            return one_hot


    def reverse_transform(self, vect):
        """Performs a conversion of a vectorized SMILES to a smiles strings
        charset must be the same as used for vectorization.
        # Arguments
            vect: Numpy array of vectorized SMILES.
        """
        smiles = []
        for v in vect:
            # mask v
            v = v[v.sum(axis=1) == 1]
            # Find one hot encoded index with argmax, translate to char and join to string
            smile = "".join(self._int_to_char[i] for i in v.argmax(axis=1))
            smiles.append(smile)
        return np.array(smiles)
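
A quick sketch of the SmilesEnumerator round trip: fit the charset and pad length on a few illustrative SMILES, one-hot encode them, and decode back. Because randomize_smiles shuffles atom order, the decoded strings are chemically equivalent but generally not identical to the inputs.

# Sketch: enumerate and vectorize a small illustrative SMILES array (requires rdkit).
import numpy as np
from enumerator import SmilesEnumerator

sme = SmilesEnumerator(enum=True, canonical=False)
smiles = np.array(["CCO", "c1ccccc1O", "CC(=O)N"])   # illustrative molecules
sme.fit(smiles)                   # derives charset and pad length from the data

one_hot = sme.transform(smiles)   # (3, pad, charset size) int8 one-hot tensor
print(one_hot.shape)
print(sme.reverse_transform(one_hot))   # randomized but equivalent SMILES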
main.py
CHANGED
@@ -1,6 +1,16 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-import
+from typing import Dict, List, Any, Tuple
+import pickle
+import math
+import re
+import gc
+from utils import split
+import torch
+from build_vocab import WordVocab
+from pretrain_trfm import TrfmSeq2seq
+from transformers import T5EncoderModel, T5Tokenizer
+import numpy as np
 
 app = FastAPI()
 
@@ -12,6 +22,182 @@ app.add_middleware(
     allow_headers=["*"]
 )
 
-@app.get("/")
-def
-
+@app.get("/predict")
+def predict_UniKP_values(
+    sequence: str,
+    smiles: str
+):
+    endpointHandler = EndpointHandler()
+    result = endpointHandler.predict({
+        "inputs": {
+            "sequence": sequence,
+            "smiles": smiles
+        }
+    })
+
+    return result
+
+
+class EndpointHandler():
+    def __init__(self, path=""):
+
+        # load tokenizer and model
+        self.tokenizer = T5Tokenizer.from_pretrained(
+            "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False, torch_dtype=torch.float16)
+        self.model = T5EncoderModel.from_pretrained(
+            "Rostlab/prot_t5_xl_half_uniref50-enc")
+
+        # path to the vocab_content and trfm model
+        vocab_content_path = f"{path}/vocab_content.txt"
+        trfm_path = f"{path}/trfm_12_23000.pkl"
+
+        # load the vocab_content instead of the pickle file
+        with open(vocab_content_path, "r", encoding="utf-8") as f:
+            vocab_content = f.read().strip().split("\n")
+
+        # load the vocab and trfm model
+        self.vocab = WordVocab(vocab_content)
+        self.trfm = TrfmSeq2seq(len(self.vocab), 256, len(self.vocab), 4)
+        self.trfm.load_state_dict(torch.load(trfm_path))
+        self.trfm.eval()
+
+        # path to the pretrained models
+        self.Km_model_path = f"{path}/Km.pkl"
+        self.Kcat_model_path = f"{path}/Kcat.pkl"
+        self.Kcat_over_Km_model_path = f"{path}/Kcat_over_Km.pkl"
+
+        # vocab indices
+        self.pad_index = 0
+        self.unk_index = 1
+        self.eos_index = 2
+        self.sos_index = 3
+
+    def predict(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        Function where the endpoint logic is implemented.
+
+        Args:
+            data (Dict[str, Any]): The input data for the endpoint. It only contains a single key "inputs", which is a dictionary with the following keys:
+
+            - sequence (str): Amino acid sequence.
+            - smiles (str): SMILES representation of the molecule.
+
+        Returns:
+            Dict[str, Any]: The output data for the endpoint. The dictionary contains the following keys:
+
+            - Km (float): predicted Km value.
+            - Kcat (float): predicted Kcat value.
+            - Vmax (float): predicted Vmax value.
+        """
+
+        sequence = data["inputs"]["sequence"]
+        smiles = data["inputs"]["smiles"]
+
+        seq_vec = self.Seq_to_vec(sequence)
+        smiles_vec = self.smiles_to_vec(smiles)
+
+        fused_vector = np.concatenate((smiles_vec, seq_vec), axis=1)
+
+        pred_Km = self.predict_feature_using_model(
+            fused_vector, self.Km_model_path)
+        pred_Kcat = self.predict_feature_using_model(
+            fused_vector, self.Kcat_model_path)
+        pred_Vmax = self.predict_feature_using_model(
+            fused_vector, self.Kcat_over_Km_model_path)
+
+        result = {
+            "Km": pred_Km,
+            "Kcat": pred_Kcat,
+            "Vmax": pred_Vmax,
+        }
+
+        return result
+
+    def predict_feature_using_model(self, X: np.array, model_path: str) -> float:
+        """
+        Function to predict the feature using the pretrained model.
+        """
+        with open(model_path, "rb") as f:
+            model = pickle.load(f)
+        pred_feature = model.predict(X)
+        pred_feature_pow = math.pow(10, pred_feature)
+        return pred_feature_pow
+
+    def smiles_to_vec(self, Smiles: str) -> np.array:
+        """
+        Function to convert the smiles to a vector using the pretrained model.
+        """
+        Smiles = [Smiles]
+
+        x_split = [split(sm) for sm in Smiles]
+        xid, xseg = self.get_array(x_split, self.vocab)
+        X = self.trfm.encode(torch.t(xid))
+        return X
+
+    def get_inputs(self, sm: str, vocab: WordVocab) -> Tuple[List[int], List[int]]:
+        """
+        Convert smiles to tensor
+        """
+        seq_len = len(sm)
+        sm = sm.split()
+        ids = [vocab.stoi.get(token, self.unk_index) for token in sm]
+        ids = [self.sos_index] + ids + [self.eos_index]
+        seg = [1]*len(ids)
+        padding = [self.pad_index]*(seq_len - len(ids))
+        ids.extend(padding), seg.extend(padding)
+        return ids, seg
+
+    def get_array(self, smiles: list[str], vocab: WordVocab) -> Tuple[torch.tensor, torch.tensor]:
+        """
+        Convert smiles to tensor
+        """
+        x_id, x_seg = [], []
+        for sm in smiles:
+            a, b = self.get_inputs(sm, vocab)
+            x_id.append(a)
+            x_seg.append(b)
+        return torch.tensor(x_id), torch.tensor(x_seg)
+
+    def Seq_to_vec(self, Sequence: str) -> np.array:
+        """
+        Function to convert the sequence to a vector using the pretrained model.
+        """
+
+        Sequence = [Sequence]
+        sequences_Example = []
+        for i in range(len(Sequence)):
+            zj = ''
+            for j in range(len(Sequence[i]) - 1):
+                zj += Sequence[i][j] + ' '
+            zj += Sequence[i][-1]
+            sequences_Example.append(zj)
+
+        gc.collect()
+        print(torch.cuda.is_available())
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+        self.model = self.model.to(device)
+        self.model = self.model.eval()
+
+        features = []
+        for i in range(len(sequences_Example)):
+            sequences_Example_i = sequences_Example[i]
+            sequences_Example_i = [re.sub(r"[UZOB]", "X", sequences_Example_i)]
+            ids = self.tokenizer.batch_encode_plus(sequences_Example_i, add_special_tokens=True, padding=True)
+            input_ids = torch.tensor(ids['input_ids']).to(device)
+            attention_mask = torch.tensor(ids['attention_mask']).to(device)
+            with torch.no_grad():
+                embedding = self.model(input_ids=input_ids, attention_mask=attention_mask)
+            embedding = embedding.last_hidden_state.cpu().numpy()
+            for seq_num in range(len(embedding)):
+                seq_len = (attention_mask[seq_num] == 1).sum()
+                seq_emd = embedding[seq_num][:seq_len - 1]
+                features.append(seq_emd)
+
+        features_normalize = np.zeros([len(features), len(features[0][0])], dtype=float)
+        for i in range(len(features)):
+            for k in range(len(features[0][0])):
+                for j in range(len(features[i])):
+                    features_normalize[i][k] += features[i][j][k]
+                features_normalize[i][k] /= len(features[i])
+        return features_normalize
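
A sketch of exercising the handler directly with the model files in the working directory; the protein fragment and SMILES are illustrative. Note that the /predict route constructs EndpointHandler() with the default path="", so inside the Space the .pkl and vocab files are expected at the filesystem root, and the first call is slow because the ProtT5 encoder is downloaded from the Hub.

# Sketch: call the handler directly (illustrative inputs; files resolved from path=".").
from main import EndpointHandler

handler = EndpointHandler(path=".")
result = handler.predict({
    "inputs": {
        "sequence": "MTEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIE",   # illustrative protein fragment
        "smiles": "CC(=O)Oc1ccccc1C(=O)O",                    # aspirin, illustrative
    }
})
print(result)   # {"Km": ..., "Kcat": ..., "Vmax": ...}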
pretrain_trfm.py
ADDED
@@ -0,0 +1,175 @@
import argparse
import math
import os

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from build_vocab import WordVocab
from dataset import Seq2seqDataset

PAD = 0
UNK = 1
EOS = 2
SOS = 3
MASK = 4

class PositionalEncoding(nn.Module):
    "Implement the PE function. No batch support?"
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)  # (T,H)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)],
                         requires_grad=False)
        return self.dropout(x)

class TrfmSeq2seq(nn.Module):
    def __init__(self, in_size, hidden_size, out_size, n_layers, dropout=0.1):
        super(TrfmSeq2seq, self).__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(in_size, hidden_size)
        self.pe = PositionalEncoding(hidden_size, dropout)
        self.trfm = nn.Transformer(d_model=hidden_size, nhead=4,
                                   num_encoder_layers=n_layers, num_decoder_layers=n_layers, dim_feedforward=hidden_size)
        self.out = nn.Linear(hidden_size, out_size)

    def forward(self, src):
        # src: (T,B)
        embedded = self.embed(src)  # (T,B,H)
        embedded = self.pe(embedded)  # (T,B,H)
        hidden = self.trfm(embedded, embedded)  # (T,B,H)
        out = self.out(hidden)  # (T,B,V)
        out = F.log_softmax(out, dim=2)  # (T,B,V)
        return out  # (T,B,V)

    def _encode(self, src):
        # src: (T,B)
        embedded = self.embed(src)  # (T,B,H)
        embedded = self.pe(embedded)  # (T,B,H)
        output = embedded
        for i in range(self.trfm.encoder.num_layers - 1):
            output = self.trfm.encoder.layers[i](output, None)  # (T,B,H)
        penul = output.detach().numpy()
        output = self.trfm.encoder.layers[-1](output, None)  # (T,B,H)
        if self.trfm.encoder.norm:
            output = self.trfm.encoder.norm(output)  # (T,B,H)
        output = output.detach().numpy()
        # mean, max, first*2
        return np.hstack([np.mean(output, axis=0), np.max(output, axis=0), output[0, :, :], penul[0, :, :]])  # (B,4H)

    def encode(self, src):
        # src: (T,B)
        batch_size = src.shape[1]
        if batch_size <= 100:
            return self._encode(src)
        else:  # Batch is too large to load
            print('There are {:d} molecules. It will take a little time.'.format(batch_size))
            st, ed = 0, 100
            out = self._encode(src[:, st:ed])  # (B,4H)
            while ed < batch_size:
                st += 100
                ed += 100
                out = np.concatenate([out, self._encode(src[:, st:ed])], axis=0)
            return out

def parse_arguments():
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--n_epoch', '-e', type=int, default=5, help='number of epochs')
    parser.add_argument('--vocab', '-v', type=str, default='data/vocab.pkl', help='vocabulary (.pkl)')
    parser.add_argument('--data', '-d', type=str, default='data/chembl_25.csv', help='train corpus (.csv)')
    parser.add_argument('--out-dir', '-o', type=str, default='../result', help='output directory')
    parser.add_argument('--name', '-n', type=str, default='ST', help='model name')
    parser.add_argument('--seq_len', type=int, default=220, help='maximum length of the paired sequence')
    parser.add_argument('--batch_size', '-b', type=int, default=8, help='batch size')
    parser.add_argument('--n_worker', '-w', type=int, default=16, help='number of workers')
    parser.add_argument('--hidden', type=int, default=256, help='length of hidden vector')
    parser.add_argument('--n_layer', '-l', type=int, default=4, help='number of layers')
    parser.add_argument('--n_head', type=int, default=4, help='number of attention heads')
    parser.add_argument('--lr', type=float, default=1e-4, help='Adam learning rate')
    parser.add_argument('--gpu', metavar='N', type=int, nargs='+', help='list of GPU IDs to use')
    return parser.parse_args()


def evaluate(model, test_loader, vocab):
    model.eval()
    total_loss = 0
    for b, sm in enumerate(test_loader):
        sm = torch.t(sm.cuda())  # (T,B)
        with torch.no_grad():
            output = model(sm)  # (T,B,V)
        loss = F.nll_loss(output.view(-1, len(vocab)),
                          sm.contiguous().view(-1),
                          ignore_index=PAD)
        total_loss += loss.item()
    return total_loss / len(test_loader)

def main():
    args = parse_arguments()
    assert torch.cuda.is_available()

    print('Loading dataset...')
    vocab = WordVocab.load_vocab(args.vocab)
    dataset = Seq2seqDataset(pd.read_csv(args.data)['canonical_smiles'].values, vocab)
    test_size = 10000
    train, test = torch.utils.data.random_split(dataset, [len(dataset) - test_size, test_size])
    train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, num_workers=args.n_worker)
    test_loader = DataLoader(test, batch_size=args.batch_size, shuffle=False, num_workers=args.n_worker)
    print('Train size:', len(train))
    print('Test size:', len(test))
    del dataset, train, test

    model = TrfmSeq2seq(len(vocab), args.hidden, len(vocab), args.n_layer).cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print(model)
    print('Total parameters:', sum(p.numel() for p in model.parameters()))

    best_loss = None
    for e in range(1, args.n_epoch):
        for b, sm in tqdm(enumerate(train_loader)):
            sm = torch.t(sm.cuda())  # (T,B)
            optimizer.zero_grad()
            output = model(sm)  # (T,B,V)
            loss = F.nll_loss(output.view(-1, len(vocab)),
                              sm.contiguous().view(-1), ignore_index=PAD)
            loss.backward()
            optimizer.step()
            if b % 1000 == 0:
                print('Train {:3d}: iter {:5d} | loss {:.3f} | ppl {:.3f}'.format(e, b, loss.item(), math.exp(loss.item())))
            if b % 10000 == 0:
                loss = evaluate(model, test_loader, vocab)
                print('Val {:3d}: iter {:5d} | loss {:.3f} | ppl {:.3f}'.format(e, b, loss, math.exp(loss)))
                # Save the model if the validation loss is the best we've seen so far.
                if not best_loss or loss < best_loss:
                    print("[!] saving model...")
                    if not os.path.isdir(".save"):
                        os.makedirs(".save")
                    torch.save(model.state_dict(), './.save/trfm_new_%d_%d.pkl' % (e, b))
                    best_loss = loss


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt as e:
        print("[STOP]", e)
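
The part of this file the endpoint actually uses is TrfmSeq2seq.encode: it stacks mean, max, first-token and penultimate-layer features into a 4*hidden vector, i.e. 1024 dimensions for the hidden size of 256 configured in main.py. A minimal sketch with randomly initialized weights; in practice the committed trfm_12_23000.pkl is loaded via load_state_dict first.

# Sketch: one SMILES -> 1024-d fingerprint (weights untrained in this sketch).
import torch
from build_vocab import WordVocab
from pretrain_trfm import TrfmSeq2seq
from utils import split

with open("vocab_content.txt", encoding="utf-8") as f:
    vocab = WordVocab(f.read().strip().split("\n"))

trfm = TrfmSeq2seq(len(vocab), 256, len(vocab), 4)
# trfm.load_state_dict(torch.load("trfm_12_23000.pkl", map_location="cpu"))  # pretrained weights
trfm.eval()

tokens = split("CC(=O)Oc1ccccc1C(=O)O").split()
ids = [vocab.sos_index] + [vocab.stoi.get(t, vocab.unk_index) for t in tokens] + [vocab.eos_index]
src = torch.tensor([ids]).t()       # (T, B) with B = 1
print(trfm.encode(src).shape)       # (1, 1024): mean | max | first | penultimate-first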
trfm_12_23000.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6b56c8c05d048e7c7d143c4e3ba2bc6f76e5eda2358798cf636210406a700eb2
size 22128521
utils.py
ADDED
@@ -0,0 +1,194 @@
import torch
import math
import torch.nn as nn
from rdkit import Chem
from rdkit import rdBase
rdBase.DisableLog('rdApp.*')

# Split SMILES into words
def split(sm):
    '''
    function: Split SMILES into words. Care for Cl, Br, Si, Se, Na etc.
    input: A SMILES
    output: A string with space between words
    '''
    arr = []
    i = 0
    while i < len(sm)-1:
        if not sm[i] in ['%', 'C', 'B', 'S', 'N', 'R', 'X', 'L', 'A', 'M', \
                         'T', 'Z', 's', 't', 'H', '+', '-', 'K', 'F']:
            arr.append(sm[i])
            i += 1
        elif sm[i]=='%':
            arr.append(sm[i:i+3])
            i += 3
        elif sm[i]=='C' and sm[i+1]=='l':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='C' and sm[i+1]=='a':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='C' and sm[i+1]=='u':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='B' and sm[i+1]=='r':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='B' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='B' and sm[i+1]=='a':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='B' and sm[i+1]=='i':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='S' and sm[i+1]=='i':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='S' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='S' and sm[i+1]=='r':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='N' and sm[i+1]=='a':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='N' and sm[i+1]=='i':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='R' and sm[i+1]=='b':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='R' and sm[i+1]=='a':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='X' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='L' and sm[i+1]=='i':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='A' and sm[i+1]=='l':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='A' and sm[i+1]=='s':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='A' and sm[i+1]=='g':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='A' and sm[i+1]=='u':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='M' and sm[i+1]=='g':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='M' and sm[i+1]=='n':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='T' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='Z' and sm[i+1]=='n':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='s' and sm[i+1]=='i':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='s' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='t' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='H' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='+' and sm[i+1]=='2':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='+' and sm[i+1]=='3':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='+' and sm[i+1]=='4':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='-' and sm[i+1]=='2':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='-' and sm[i+1]=='3':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='-' and sm[i+1]=='4':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='K' and sm[i+1]=='r':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='F' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        else:
            arr.append(sm[i])
            i += 1
    if i == len(sm)-1:
        arr.append(sm[i])
    return ' '.join(arr)

# activation function (GELU)
class GELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

# position-wise feed-forward network
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        return self.w_2(self.dropout(self.activation(self.w_1(x))))

# normalization layer
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

# Sample SMILES from probabilistic distribution
def sample(msms):
    ret = []
    for msm in msms:
        ret.append(torch.multinomial(msm.exp(), 1).squeeze())
    return torch.stack(ret)

def validity(smiles):
    loss = 0
    for sm in smiles:
        mol = Chem.MolFromSmiles(sm)
        if mol is None:
            loss += 1
    return 1-loss/len(smiles)
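
split() is the tokenizer the whole SMILES pipeline relies on: two-character atoms and charges (Cl, Br, Si, Na, +2, ...) are kept as single tokens, everything else is split per character. A quick illustrative check:

# Sketch: two-character atoms survive tokenization as single tokens.
from utils import split

print(split("CC(=O)Nc1ccc(Cl)cc1"))
# -> "C C ( = O ) N c 1 c c c ( Cl ) c c 1"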
vocab.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:21a66c850a3222547ec0fbd30c05fe587d66d22d3de2ee2195c58250fe486fb7
size 1446
vocab_content.txt
ADDED
@@ -0,0 +1,45 @@
<pad>
<unk>
<eos>
<sos>
<mask>
c
C
(
)
O
=
1
N
2
3
n
4
@
[
]
H
F
5
S
\
Cl
s
6
o
+
-
#
/
.
Br
7
P
I
8
Na
B
Si
Se
9
K