import os

import pandas as pd
import torch
from transformers import RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding

import datasets
from datasets import disable_caching

# Don't write intermediate Arrow cache files for every map/filter call;
# results are persisted explicitly with save_to_disk below.
disable_caching()

DEVICE = 'cuda:0'                                  # model device
ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m"   # encoder name
ENCODER_BATCH_SIZE = 1024                          # batch size for computing embeddings
TOKENIZER_MAX_LEN = 256                            # max_length param on tokenizer
TOKENIZATION_NUM_PROC = 32                         # number of processes for tokenization

'''
Data source is expected to be a CSV file with a column of SMILES strings
denoted by `SMILES_COLUMN`. The CSV is processed in chunks of size
`PROCESS_CHUNKSIZE`. Processed chunks are saved to `SAVE_PATH` with the
format `SAVE_PATH/processed_shard_{i}.hf`
'''

DATASET_CSV_FILENAME = None   # path to data csv
PROCESS_CHUNKSIZE = 1000000   # how many rows to process/save for each dataset shard
SMILES_COLUMN = 'smiles'      # csv column holding smiles strings
MAX_CHUNKS = None             # total number of chunks to process (if None, all chunks are processed)
MAX_SMILES_LENGTH = 90        # max smiles string length (exclusive)
MIN_SMILES_LENGTH = 5         # min smiles string length (exclusive)
FILTER_NUM_PROC = 32          # number of processes for filtering
SAVE_PATH = None              # directory to save data shards to

assert DATASET_CSV_FILENAME is not None, "must specify dataset filename"
assert SAVE_PATH is not None, "must specify save path"


def tokenization(example):
    '''Tokenize a batch of SMILES strings.

    Adds special tokens and truncates each sequence to TOKENIZER_MAX_LEN.
    Intended for use with `datasets.Dataset.map(batched=True)`.
    '''
    return tokenizer(example[SMILES_COLUMN], add_special_tokens=True,
                     truncation=True, max_length=TOKENIZER_MAX_LEN)


def embed(inputs):
    '''Compute mean-pooled final-layer hidden states for a tokenized batch.

    Pads the batch to a uniform length with the collator, runs the encoder,
    then averages the last hidden layer over non-padding positions
    (attention_mask == 1) so padding tokens do not dilute the mean.
    Returns a dict providing the new 'encoder_hidden_states' column.
    '''
    inputs = {k: inputs[k] for k in ['input_ids', 'attention_mask']}
    inputs = collator(inputs)  # pad to longest in batch, convert to tensors
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # outputs[-1] is the hidden_states tuple; its last element is the
    # final layer's activations, shape (batch, seq, hidden_dim).
    full_embeddings = outputs[-1][-1]
    mask = inputs['attention_mask']
    # Masked mean over the sequence dimension.
    mean_embeddings = ((full_embeddings * mask.unsqueeze(-1)).sum(1)
                       / mask.sum(-1).unsqueeze(-1))
    return {'encoder_hidden_states': mean_embeddings}


def length_filter_smiles(example):
    '''Keep rows whose SMILES length is strictly between the configured bounds.

    Either bound may be None, in which case that side is not checked.
    NOTE(review): the '<' in the max-length comparison was lost to
    markup-escaping damage in the source file; it is restored here per the
    "max smiles string length (exclusive)" comment on MAX_SMILES_LENGTH.
    '''
    smiles_len = len(example[SMILES_COLUMN])
    min_check = (smiles_len > MIN_SMILES_LENGTH) if (MIN_SMILES_LENGTH is not None) else True
    max_check = (smiles_len < MAX_SMILES_LENGTH) if (MAX_SMILES_LENGTH is not None) else True
    return min_check and max_check


# ---------------------------------------------------------------------------
# NOTE(review): the source file lost everything between the body of
# `length_filter_smiles` and the trailing `... MAX_CHUNKS-1): break` fragment.
# The setup and processing loop below are reconstructed from the module
# docstring and the surviving fragments — confirm against the original script.
# ---------------------------------------------------------------------------

tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME, max_len=TOKENIZER_MAX_LEN)
model = RobertaForMaskedLM.from_pretrained(ENCODER_MODEL_NAME)
model.to(DEVICE)
model.eval()  # from_pretrained already returns eval mode; made explicit for safety
collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')

os.makedirs(SAVE_PATH, exist_ok=True)

# Stream the CSV in PROCESS_CHUNKSIZE-row chunks so the whole file never has
# to fit in memory; each chunk becomes one saved dataset shard.
df_iterator = pd.read_csv(DATASET_CSV_FILENAME, chunksize=PROCESS_CHUNKSIZE)

for i, df in enumerate(df_iterator):
    dataset = datasets.Dataset.from_pandas(df)
    dataset = dataset.filter(length_filter_smiles, num_proc=FILTER_NUM_PROC)
    dataset = dataset.map(tokenization, batched=True, num_proc=TOKENIZATION_NUM_PROC)
    # Embedding runs on the GPU, so it stays single-process; the batch size
    # bounds peak VRAM use.
    dataset = dataset.map(embed, batched=True, batch_size=ENCODER_BATCH_SIZE)
    dataset.save_to_disk(f'{SAVE_PATH}/processed_shard_{i}.hf')

    # Stop early once MAX_CHUNKS shards have been produced (i is 0-indexed).
    if (MAX_CHUNKS is not None) and (i >= MAX_CHUNKS - 1):
        break

print('finished data processing')