# Deep learning
import torch
# Data
from pubchem_encoder import Encoder
from datasets import load_dataset, concatenate_datasets
# Standard library
import os
import getpass
import glob
class MoleculeModule:
    # Resembles a PyTorch Lightning DataModule (prepare_data/setup hooks).
    def __init__(self, max_len, dataset, data_path):
        super().__init__()
        self.dataset = dataset
        self.data_path = data_path
        self.text_encoder = Encoder(max_len)

    def prepare_data(self):
        pass

    def get_vocab(self):
        # Using a homemade tokenizer; should look into an existing tokenizer.
        return self.text_encoder.char2id

    def get_cache(self):
        return self.cache_files

    def setup(self, stage=None):
        # Using the Hugging Face datasets loader.
        # Create the cache in the local machine's /tmp directory under the
        # current user's name to prevent locking issues.
        pubchem_path = {'train': self.data_path}
        if 'canonical' in pubchem_path['train'].lower():
            pubchem_script = './pubchem_canon_script.py'
        else:
            pubchem_script = './pubchem_script.py'
        zinc_path = './data/ZINC'
        global dataset_dict
        if 'ZINC' in self.dataset or 'zinc' in self.dataset:
            zinc_files = glob.glob(os.path.join(zinc_path, '*.smi'))
            for zfile in zinc_files:
                print(zfile)
            self.dataset = {'train': zinc_files}
            dataset_dict = load_dataset('./zinc_script.py',
                                        data_files=self.dataset,
                                        cache_dir=os.path.join('/tmp', getpass.getuser(), 'zinc'),
                                        split='train',
                                        trust_remote_code=True)
        elif 'pubchem' in self.dataset:
            dataset_dict = load_dataset(pubchem_script,
                                        data_files=pubchem_path,
                                        cache_dir=os.path.join('/tmp', getpass.getuser(), 'pubchem'),
                                        split='train',
                                        trust_remote_code=True)
        elif 'both' in self.dataset or 'Both' in self.dataset or 'BOTH' in self.dataset:
            dataset_dict_pubchem = load_dataset(pubchem_script,
                                                data_files=pubchem_path,
                                                cache_dir=os.path.join('/tmp', getpass.getuser(), 'pubchem'),
                                                split='train',
                                                trust_remote_code=True)
            zinc_files = glob.glob(os.path.join(zinc_path, '*.smi'))
            for zfile in zinc_files:
                print(zfile)
            self.dataset = {'train': zinc_files}
            dataset_dict_zinc = load_dataset('./zinc_script.py',
                                             data_files=self.dataset,
                                             cache_dir=os.path.join('/tmp', getpass.getuser(), 'zinc'),
                                             split='train',
                                             trust_remote_code=True)
            dataset_dict = concatenate_datasets([dataset_dict_zinc, dataset_dict_pubchem])
        self.pubchem = dataset_dict
        print(dataset_dict.cache_files)
        # Record the cache directories (first four path components) so they
        # can be retrieved later via get_cache().
        self.cache_files = []
        for cache in dataset_dict.cache_files:
            tmp = '/'.join(cache['filename'].split('/')[:4])
            self.cache_files.append(tmp)
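
    # Minimal sketch of how the dataset and encoder above would be consumed.
    # This method is an illustration, not part of the original file: it
    # assumes the Encoder exposes a `process` method usable as a collate_fn
    # (hypothetical here) and that batch_size is configured elsewhere.
    def train_dataloader(self):
        from torch.utils.data import DataLoader
        return DataLoader(self.pubchem,
                          batch_size=64,  # assumed value; tune for your setup
                          shuffle=True,
                          collate_fn=self.text_encoder.process)  # hypothetical Encoder API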

def get_optim_groups(module, weight_decay=0.0):
    # Set up the optimizer parameter groups: separate out all parameters into
    # those that will and won't experience regularizing weight decay.
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear,)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in module.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            if pn.endswith('bias'):
                # All biases are excluded from weight decay.
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # Weights of whitelisted modules are weight decayed.
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # Weights of blacklisted modules are NOT weight decayed.
                no_decay.add(fpn)
    # Validate that we considered every parameter exactly once.
    param_dict = {pn: p for pn, p in module.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, \
        "parameters %s made it into both decay/no_decay sets" % (str(inter_params),)
    assert len(param_dict.keys() - union_params) == 0, \
        "parameters %s were not separated into either decay/no_decay set" % (str(param_dict.keys() - union_params),)
    # Create the PyTorch optimizer parameter groups; only the decay group
    # receives the weight_decay coefficient (default 0.0).
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]
    return optim_groups
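

# --- Usage sketch (illustration only, not part of the original file) ---
# Shows how MoleculeModule and get_optim_groups are meant to fit together.
# The max_len, dataset name, data path, and hyperparameters below are
# assumed values, not taken from the original source.
if __name__ == '__main__':
    data_module = MoleculeModule(max_len=202,
                                 dataset='pubchem',
                                 data_path='./data/pubchem/CID-SMILES')  # hypothetical path
    data_module.setup()
    vocab = data_module.get_vocab()
    print('vocabulary size:', len(vocab))

    # Any torch.nn.Module works here; a single Linear layer keeps the demo small.
    model = torch.nn.Linear(len(vocab), len(vocab))
    optim_groups = get_optim_groups(model, weight_decay=0.1)
    optimizer = torch.optim.AdamW(optim_groups, lr=3e-4, betas=(0.9, 0.95))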