"""PDBBind dataset, noise transform and data-loader construction.

The module builds (and caches on disk) heterogeneous graphs for
protein-ligand complexes, applies random translation / rotation / torsion
noise to ligands, and wires the resulting datasets into PyTorch Geometric
loaders.
"""
import binascii
import glob
import hashlib
import os
import pickle
from collections import defaultdict
from multiprocessing import Pool
import random
import copy

import numpy as np
import torch
from rdkit.Chem import MolToSmiles, MolFromSmiles, AddHs
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.loader import DataLoader, DataListLoader
from torch_geometric.transforms import BaseTransform
from tqdm import tqdm

from datasets.process_mols import (
    read_molecule,
    get_rec_graph,
    generate_conformer,
    get_lig_graph_with_matching,
    extract_receptor_structure,
    parse_receptor,
    parse_pdb_from_path,
)
from utils.diffusion_utils import modify_conformer, set_time
from utils.utils import read_strings_from_txt
from utils import so3, torus

# Number of complexes processed between two on-disk checkpoints when
# preprocessing runs with several workers.
_CHUNK_SIZE = 1000


class NoiseTransform(BaseTransform):
    """Transform that perturbs a ligand pose and stores the matching scores.

    Args:
        t_to_sigma: callable ``(t_tr, t_rot, t_tor) -> (tr_sigma, rot_sigma,
            tor_sigma)`` mapping times to noise scales.
        no_torsion: when true, no torsion update is applied and the torsion
            score / sigma attributes are set to ``None``.
        all_atom: forwarded to ``set_time``.
    """

    def __init__(self, t_to_sigma, no_torsion, all_atom):
        self.t_to_sigma = t_to_sigma
        self.no_torsion = no_torsion
        self.all_atom = all_atom

    def __call__(self, data):
        """Draw one uniform time and use it for all three noise components."""
        t = np.random.uniform()
        t_tr, t_rot, t_tor = t, t, t
        return self.apply_noise(data, t_tr, t_rot, t_tor)

    def apply_noise(
        self,
        data,
        t_tr,
        t_rot,
        t_tor,
        tr_update=None,
        rot_update=None,
        torsion_updates=None,
    ):
        """Perturb ``data`` in place and attach the score targets.

        Args:
            data: complex graph with a ``"ligand"`` store exposing ``pos`` and
                ``edge_mask``.
            t_tr, t_rot, t_tor: times for translation, rotation and torsion.
            tr_update, rot_update, torsion_updates: optional pre-sampled
                updates; each one is sampled here when left as ``None``.

        Returns:
            The same ``data`` object, with ``tr_score``, ``rot_score``,
            ``tor_score`` and ``tor_sigma_edge`` set.
        """
        # ``pos`` may be a collection of candidate position tensors rather
        # than a single tensor (presumably one per conformer); pick one.
        if not torch.is_tensor(data["ligand"].pos):
            data["ligand"].pos = random.choice(data["ligand"].pos)

        tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor)
        set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None)

        if tr_update is None:
            tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3))
        if rot_update is None:
            rot_update = so3.sample_vec(eps=rot_sigma)
        if torsion_updates is None:
            # Sampled even when ``no_torsion`` is set so that the NumPy RNG
            # stream is consumed exactly as before.
            torsion_updates = np.random.normal(
                loc=0.0, scale=tor_sigma, size=data["ligand"].edge_mask.sum()
            )
        if self.no_torsion:
            torsion_updates = None

        modify_conformer(
            data, tr_update, torch.from_numpy(rot_update).float(), torsion_updates
        )

        data.tr_score = -tr_update / tr_sigma**2
        data.rot_score = (
            torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma))
            .float()
            .unsqueeze(0)
        )
        if self.no_torsion:
            data.tor_score = None
            data.tor_sigma_edge = None
        else:
            data.tor_score = torch.from_numpy(
                torus.score(torsion_updates, tor_sigma)
            ).float()
            data.tor_sigma_edge = (
                np.ones(data["ligand"].edge_mask.sum()) * tor_sigma
            )
        return data


class PDBBind(Dataset):
    """Dataset of protein-ligand complex graphs backed by an on-disk cache.

    On construction the cache directory derived from the arguments is checked
    for ``heterographs.pkl`` (and ``rdkit_ligands.pkl`` when
    ``require_ligand`` is set). If missing, the complexes are preprocessed:
    from the split file under ``root`` when no explicit inputs are given, or
    from ``protein_path_list`` / ``ligand_descriptions`` otherwise. The
    pickled graphs are then loaded into memory.

    NOTE: the cache is read with ``pickle.load``, which can execute arbitrary
    code; only point ``cache_path`` at directories you trust.
    """

    def __init__(
        self,
        root,
        transform=None,
        cache_path="data/cache",
        split_path="data/",
        limit_complexes=0,
        receptor_radius=30,
        num_workers=1,
        c_alpha_max_neighbors=None,
        popsize=15,
        maxiter=15,
        matching=True,
        keep_original=False,
        max_lig_size=None,
        remove_hs=False,
        num_conformers=1,
        all_atoms=False,
        atom_radius=5,
        atom_max_neighbors=None,
        esm_embeddings_path=None,
        require_ligand=False,
        ligands_list=None,
        protein_path_list=None,
        ligand_descriptions=None,
        keep_local_structures=False,
    ):
        super().__init__(root, transform)
        self.pdbbind_dir = root
        self.max_lig_size = max_lig_size
        self.split_path = split_path
        self.limit_complexes = limit_complexes
        self.receptor_radius = receptor_radius
        self.num_workers = num_workers
        self.c_alpha_max_neighbors = c_alpha_max_neighbors
        self.remove_hs = remove_hs
        self.esm_embeddings_path = esm_embeddings_path
        self.require_ligand = require_ligand
        self.protein_path_list = protein_path_list
        self.ligand_descriptions = ligand_descriptions
        self.keep_local_structures = keep_local_structures

        # Explicit parentheses: ``and`` binds tighter than ``or``.
        if matching or (
            protein_path_list is not None and ligand_descriptions is not None
        ):
            cache_path += "_torsion"
        if all_atoms:
            cache_path += "_allatoms"

        # The directory name encodes every option that changes the cached
        # content; keep it byte-stable so existing caches remain valid.
        self.full_cache_path = os.path.join(
            cache_path,
            f"limit{self.limit_complexes}"
            f"_INDEX{os.path.splitext(os.path.basename(self.split_path))[0]}"
            f"_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}"
            f"_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}"
            + (
                ""
                if not all_atoms
                else f"_atomRad{atom_radius}_atomMax{atom_max_neighbors}"
            )
            + ("" if not matching or num_conformers == 1 else f"_confs{num_conformers}")
            + ("" if self.esm_embeddings_path is None else f"_esmEmbeddings")
            + ("" if not keep_local_structures else f"_keptLocalStruct")
            + (
                ""
                if protein_path_list is None or ligand_descriptions is None
                else str(
                    binascii.crc32(
                        "".join(ligand_descriptions + protein_path_list).encode()
                    )
                )
            ),
        )
        self.popsize, self.maxiter = popsize, maxiter
        self.matching, self.keep_original = matching, keep_original
        self.num_conformers = num_conformers
        self.all_atoms = all_atoms
        self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors

        heterographs_file = self._cache_file("heterographs.pkl")
        ligands_file = self._cache_file("rdkit_ligands.pkl")
        cache_missing = not os.path.exists(heterographs_file) or (
            require_ligand and not os.path.exists(ligands_file)
        )
        if cache_missing:
            os.makedirs(self.full_cache_path, exist_ok=True)
            if protein_path_list is None or ligand_descriptions is None:
                self.preprocessing()
            else:
                self.inference_preprocessing()

        print("loading data from memory: ", heterographs_file)
        with open(heterographs_file, "rb") as f:
            self.complex_graphs = pickle.load(f)
        if require_ligand:
            with open(ligands_file, "rb") as f:
                self.rdkit_ligands = pickle.load(f)

        print_statistics(self.complex_graphs)

    def len(self):
        """Return the number of cached complex graphs."""
        return len(self.complex_graphs)

    def get(self, idx):
        """Return a deep copy of graph ``idx``.

        When ``require_ligand`` is set, a deep copy of the RDKit ligand stored
        at the same index is attached as ``complex_graph.mol``.
        """
        complex_graph = copy.deepcopy(self.complex_graphs[idx])
        if self.require_ligand:
            complex_graph.mol = copy.deepcopy(self.rdkit_ligands[idx])
        return complex_graph

    def preprocessing(self):
        """Build and cache graphs for the complexes listed in ``split_path``."""
        print(
            f"Processing complexes from [{self.split_path}] and saving it to [{self.full_cache_path}]"
        )

        complex_names_all = read_strings_from_txt(self.split_path)
        if self.limit_complexes is not None and self.limit_complexes != 0:
            complex_names_all = complex_names_all[: self.limit_complexes]
        print(f"Loading {len(complex_names_all)} complexes.")

        if self.esm_embeddings_path is not None:
            id_to_embeddings = torch.load(self.esm_embeddings_path)
            # Set membership keeps the grouping linear in the number of
            # embeddings instead of scanning the name list for each key.
            wanted_names = set(complex_names_all)
            chain_embeddings_dictlist = defaultdict(list)
            for key, embedding in id_to_embeddings.items():
                # The complex name is the part of the key before the first "_".
                key_name = key.split("_")[0]
                if key_name in wanted_names:
                    chain_embeddings_dictlist[key_name].append(embedding)
            lm_embeddings_chains_all = [
                chain_embeddings_dictlist[name] for name in complex_names_all
            ]
        else:
            lm_embeddings_chains_all = [None] * len(complex_names_all)

        no_ligands = [None] * len(complex_names_all)
        self._process_and_cache(
            complex_names_all, lm_embeddings_chains_all, no_ligands, no_ligands
        )

    def inference_preprocessing(self):
        """Build and cache graphs for explicitly given proteins and ligands.

        Each entry of ``ligand_descriptions`` is first tried as a SMILES
        string and otherwise read as a molecule file path.
        """
        ligands_list = []
        print("Reading molecules and generating local structures with RDKit")
        for ligand_description in tqdm(self.ligand_descriptions):
            # check if it is a smiles or a path
            mol = MolFromSmiles(ligand_description)
            print(ligand_description, mol)
            if mol is not None:
                mol = AddHs(mol)
                generate_conformer(mol)
                ligands_list.append(mol)
            else:
                # NOTE(review): no guard here for ``read_molecule`` returning
                # None — confirm how an unreadable file should be handled.
                mol = read_molecule(ligand_description, remove_hs=False, sanitize=True)
                print(mol)
                if not self.keep_local_structures:
                    mol.RemoveAllConformers()
                    mol = AddHs(mol)
                    generate_conformer(mol)
                ligands_list.append(mol)

        if self.esm_embeddings_path is not None:
            print("Reading language model embeddings.")
            lm_embeddings_chains_all = []
            if not os.path.exists(self.esm_embeddings_path):
                raise Exception(
                    "ESM embeddings path does not exist: ", self.esm_embeddings_path
                )
            for protein_path in self.protein_path_list:
                # Every file whose name starts with the protein file name is
                # loaded, in sorted order.
                embeddings_paths = sorted(
                    glob.glob(
                        os.path.join(
                            self.esm_embeddings_path, os.path.basename(protein_path)
                        )
                        + "*"
                    )
                )
                lm_embeddings_chains = [
                    torch.load(embeddings_path)["representations"][33]
                    for embeddings_path in embeddings_paths
                ]
                lm_embeddings_chains_all.append(lm_embeddings_chains)
        else:
            lm_embeddings_chains_all = [None] * len(self.protein_path_list)

        print("Generating graphs for ligands and proteins")
        self._process_and_cache(
            self.protein_path_list,
            lm_embeddings_chains_all,
            ligands_list,
            self.ligand_descriptions,
        )

    def _cache_file(self, filename):
        """Return the path of ``filename`` inside the cache directory."""
        return os.path.join(self.full_cache_path, filename)

    def _dump_to_cache(self, filename, obj):
        """Pickle ``obj`` into the cache directory under ``filename``."""
        with open(self._cache_file(filename), "wb") as f:
            pickle.dump(obj, f)

    def _merge_chunks(self, prefix, num_chunks):
        """Concatenate ``{prefix}{i}.pkl`` chunk files into ``{prefix}.pkl``."""
        merged = []
        for i in range(num_chunks):
            with open(self._cache_file(f"{prefix}{i}.pkl"), "rb") as f:
                merged.extend(pickle.load(f))
        self._dump_to_cache(f"{prefix}.pkl", merged)

    def _collect_complexes(self, work_items, total, desc):
        """Run ``get_complex`` over ``work_items`` and gather the results.

        Uses a worker pool when ``num_workers > 1`` (results then arrive in
        arbitrary order) and a plain serial ``map`` otherwise.
        """
        complex_graphs, rdkit_ligands = [], []

        def consume(results, pbar):
            for graphs, ligs in results:
                complex_graphs.extend(graphs)
                rdkit_ligands.extend(ligs)
                pbar.update()

        with tqdm(total=total, desc=desc) as pbar:
            if self.num_workers > 1:
                # The context manager tears the pool down even if a task raises.
                with Pool(self.num_workers, maxtasksperchild=1) as pool:
                    consume(pool.imap_unordered(self.get_complex, work_items), pbar)
            else:
                consume(map(self.get_complex, work_items), pbar)
        return complex_graphs, rdkit_ligands

    def _process_and_cache(
        self, names, lm_embeddings_chains_all, ligands, ligand_descriptions
    ):
        """Process all complexes and write the two final cache files.

        The four arguments are parallel sequences, one entry per complex.
        With several workers the work is checkpointed every ``_CHUNK_SIZE``
        complexes, and chunks whose graph file already exists are skipped.
        """
        if self.num_workers > 1:
            # running preprocessing in parallel on multiple workers and saving
            # the progress every 1000 complexes
            num_chunks = len(names) // _CHUNK_SIZE + 1
            for i in range(num_chunks):
                if os.path.exists(self._cache_file(f"heterographs{i}.pkl")):
                    continue
                chunk = slice(_CHUNK_SIZE * i, _CHUNK_SIZE * (i + 1))
                names_chunk = names[chunk]
                complex_graphs, rdkit_ligands = self._collect_complexes(
                    zip(
                        names_chunk,
                        lm_embeddings_chains_all[chunk],
                        ligands[chunk],
                        ligand_descriptions[chunk],
                    ),
                    total=len(names_chunk),
                    desc=f"loading complexes {i}/{num_chunks}",
                )
                self._dump_to_cache(f"heterographs{i}.pkl", complex_graphs)
                self._dump_to_cache(f"rdkit_ligands{i}.pkl", rdkit_ligands)

            self._merge_chunks("heterographs", num_chunks)
            self._merge_chunks("rdkit_ligands", num_chunks)
        else:
            complex_graphs, rdkit_ligands = self._collect_complexes(
                zip(names, lm_embeddings_chains_all, ligands, ligand_descriptions),
                total=len(names),
                desc="loading complexes",
            )
            self._dump_to_cache("heterographs.pkl", complex_graphs)
            self._dump_to_cache("rdkit_ligands.pkl", rdkit_ligands)

    def get_complex(self, par):
        """Build the graphs for one complex.

        Args:
            par: tuple ``(name, lm_embedding_chains, ligand,
                ligand_description)``. When ``ligand`` is ``None``, ``name``
                is a folder under ``pdbbind_dir`` from which receptor and
                ligands are read; otherwise ``name`` is a receptor file path
                and ``ligand`` the molecule to pair with it.

        Returns:
            ``(complex_graphs, ligands)``: two lists of equal length in which
            ``ligands[i]`` is the molecule ``complex_graphs[i]`` was built
            from. Ligands that are skipped appear in neither list.
        """
        name, lm_embedding_chains, ligand, ligand_description = par
        if not os.path.exists(os.path.join(self.pdbbind_dir, name)) and ligand is None:
            print("Folder not found", name)
            return [], []

        if ligand is not None:
            rec_model = parse_pdb_from_path(name)
            name = f"{name}____{ligand_description}"
            ligs = [ligand]
        else:
            try:
                rec_model = parse_receptor(name, self.pdbbind_dir)
            except Exception as e:
                print(f"Skipping {name} because of the error:")
                print(e)
                return [], []
            ligs = read_mols(self.pdbbind_dir, name, remove_hs=False)

        complex_graphs = []
        # Ligands that actually produced a graph. Returning the unfiltered
        # ``ligs`` would shift the graph/ligand pairing in ``get`` whenever a
        # ligand is skipped below.
        kept_ligs = []
        for lig in ligs:
            if (
                self.max_lig_size is not None
                and lig.GetNumHeavyAtoms() > self.max_lig_size
            ):
                print(
                    f"Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data."
                )
                continue
            complex_graph = HeteroData()
            complex_graph["name"] = name
            try:
                get_lig_graph_with_matching(
                    lig,
                    complex_graph,
                    self.popsize,
                    self.maxiter,
                    self.matching,
                    self.keep_original,
                    self.num_conformers,
                    remove_hs=self.remove_hs,
                )
                (
                    rec,
                    rec_coords,
                    c_alpha_coords,
                    n_coords,
                    c_coords,
                    lm_embeddings,
                ) = extract_receptor_structure(
                    copy.deepcopy(rec_model),
                    lig,
                    lm_embedding_chains=lm_embedding_chains,
                )
                if lm_embeddings is not None and len(c_alpha_coords) != len(
                    lm_embeddings
                ):
                    print(
                        f"LM embeddings for complex {name} did not have the right length for the protein. Skipping {name}."
                    )
                    continue

                get_rec_graph(
                    rec,
                    rec_coords,
                    c_alpha_coords,
                    n_coords,
                    c_coords,
                    complex_graph,
                    rec_radius=self.receptor_radius,
                    c_alpha_max_neighbors=self.c_alpha_max_neighbors,
                    all_atoms=self.all_atoms,
                    atom_radius=self.atom_radius,
                    atom_max_neighbors=self.atom_max_neighbors,
                    remove_hs=self.remove_hs,
                    lm_embeddings=lm_embeddings,
                )
            except Exception as e:
                # Best effort: report the failure and move on to the next
                # ligand, as the message says (same policy as the receptor
                # parsing above).
                print(f"Skipping {name} because of the error:")
                print(e)
                continue

            # Re-centre all coordinates on the mean receptor position.
            protein_center = torch.mean(
                complex_graph["receptor"].pos, dim=0, keepdim=True
            )
            complex_graph["receptor"].pos -= protein_center
            if self.all_atoms:
                complex_graph["atom"].pos -= protein_center

            if (not self.matching) or self.num_conformers == 1:
                complex_graph["ligand"].pos -= protein_center
            else:
                # Otherwise ``pos`` is iterated and each element shifted in place.
                for p in complex_graph["ligand"].pos:
                    p -= protein_center

            complex_graph.original_center = protein_center
            complex_graphs.append(complex_graph)
            kept_ligs.append(lig)
        return complex_graphs, kept_ligs


def print_statistics(complex_graphs):
    """Print mean / std / max of a few geometric statistics of the graphs.

    The statistics are the largest receptor-node norm, the ligand radius
    around its own centroid, the norm of the ligand centroid, and
    ``rmsd_matching`` (0 when a graph has no such entry). With an empty input
    only the number of complexes is printed.
    """
    statistics = ([], [], [], [])

    for complex_graph in complex_graphs:
        lig_pos = complex_graph["ligand"].pos
        if not torch.is_tensor(lig_pos):
            # Several candidate position tensors: report on the first one.
            lig_pos = lig_pos[0]
        radius_protein = torch.max(
            torch.linalg.vector_norm(complex_graph["receptor"].pos, dim=1)
        )
        molecule_center = torch.mean(lig_pos, dim=0)
        radius_molecule = torch.max(
            torch.linalg.vector_norm(lig_pos - molecule_center.unsqueeze(0), dim=1)
        )
        distance_center = torch.linalg.vector_norm(molecule_center)
        statistics[0].append(radius_protein)
        statistics[1].append(radius_molecule)
        statistics[2].append(distance_center)
        if "rmsd_matching" in complex_graph:
            statistics[3].append(complex_graph.rmsd_matching)
        else:
            statistics[3].append(0)

    name = [
        "radius protein",
        "radius molecule",
        "distance protein-mol",
        "rmsd matching",
    ]
    print("Number of complexes: ", len(complex_graphs))
    if not complex_graphs:
        # np.max raises ValueError on a zero-size array; nothing to report.
        return
    for label, values in zip(name, statistics):
        array = np.asarray(values)
        print(
            f"{label}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}"
        )


def construct_loader(args, t_to_sigma):
    """Create the training and validation loaders described by ``args``.

    Both datasets share a ``NoiseTransform`` and the same graph-construction
    options; only the training set receives ``args.num_conformers``.
    ``DataListLoader`` is used when CUDA is available, ``DataLoader``
    otherwise.

    Returns:
        ``(train_loader, val_loader)``.
    """
    transform = NoiseTransform(
        t_to_sigma=t_to_sigma, no_torsion=args.no_torsion, all_atom=args.all_atoms
    )

    common_args = {
        "transform": transform,
        "root": args.data_dir,
        "limit_complexes": args.limit_complexes,
        "receptor_radius": args.receptor_radius,
        "c_alpha_max_neighbors": args.c_alpha_max_neighbors,
        "remove_hs": args.remove_hs,
        "max_lig_size": args.max_lig_size,
        "matching": not args.no_torsion,
        "popsize": args.matching_popsize,
        "maxiter": args.matching_maxiter,
        "num_workers": args.num_workers,
        "all_atoms": args.all_atoms,
        "atom_radius": args.atom_radius,
        "atom_max_neighbors": args.atom_max_neighbors,
        "esm_embeddings_path": args.esm_embeddings_path,
    }

    train_dataset = PDBBind(
        cache_path=args.cache_path,
        split_path=args.split_train,
        keep_original=True,
        num_conformers=args.num_conformers,
        **common_args,
    )
    val_dataset = PDBBind(
        cache_path=args.cache_path,
        split_path=args.split_val,
        keep_original=True,
        **common_args,
    )

    loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
    train_loader = loader_class(
        dataset=train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_dataloader_workers,
        shuffle=True,
        pin_memory=args.pin_memory,
    )
    # NOTE(review): the validation loader is shuffled too — confirm intended.
    val_loader = loader_class(
        dataset=val_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_dataloader_workers,
        shuffle=True,
        pin_memory=args.pin_memory,
    )

    return train_loader, val_loader


def read_mol(pdbbind_dir, name, remove_hs=False):
    """Read ``{name}_ligand.sdf`` from the complex folder.

    Falls back to ``{name}_ligand.mol2`` when the ``.sdf`` read returns
    ``None``. The result may still be ``None``.
    """
    lig = read_molecule(
        os.path.join(pdbbind_dir, name, f"{name}_ligand.sdf"),
        remove_hs=remove_hs,
        sanitize=True,
    )
    if lig is None:  # read mol2 file if sdf file cannot be sanitized
        lig = read_molecule(
            os.path.join(pdbbind_dir, name, f"{name}_ligand.mol2"),
            remove_hs=remove_hs,
            sanitize=True,
        )
    return lig


def read_mols(pdbbind_dir, name, remove_hs=False):
    """Read every ligand found in the folder of complex ``name``.

    All ``.sdf`` files whose name does not contain ``"rdkit"`` are read; when
    one cannot be read and a ``.mol2`` file with the same stem exists, that
    file is tried instead. Molecules that still fail to load are left out.

    Returns:
        List of loaded molecules, in ``os.listdir`` order.
    """
    ligs = []
    folder = os.path.join(pdbbind_dir, name)
    for filename in os.listdir(folder):
        if not filename.endswith(".sdf") or "rdkit" in filename:
            continue
        lig = read_molecule(
            os.path.join(folder, filename),
            remove_hs=remove_hs,
            sanitize=True,
        )
        mol2_path = os.path.join(folder, filename[:-4] + ".mol2")
        # read mol2 file if sdf file cannot be sanitized
        if lig is None and os.path.exists(mol2_path):
            print(
                "Using the .sdf file failed. We found a .mol2 file instead and are trying to use that."
            )
            lig = read_molecule(mol2_path, remove_hs=remove_hs, sanitize=True)
        if lig is not None:
            ligs.append(lig)
    return ligs