# M3Site / predict.py
import sys
import os
sys.path.append(f'{os.getcwd()}/esm3')  # make the local esm3 directory importable
import torch
from utils.util_functions import read_pdb, extract_sequence, get_features
from Bio.PDB import PDBParser
from hadder import AddHydrogen
import numpy as np
from transformers import AutoTokenizer, AutoModel
from utils.get_property_embs import get_ss8_dim9, get_dihedrals_dim16, get_atom_features_dim7, get_hbond_features_dim2, get_pef_features_dim1, get_residue_features_dim27
from torch_geometric.data import Data
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.sdk.api import ESMProtein, SamplingConfig
def model_predict(model, pdb_file, function):
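    """Predict functional-site residues for a single protein structure.

    Args:
        model: name of a released checkpoint, e.g. 'M3Site-ESM3-abs',
            'M3Site-ESM2-full' or 'M3Site-ESM1b-abs'.
        pdb_file: path to the input PDB file.
        function: free-text description of the protein's function; an empty
            string is treated as 'Unknown'.

    Returns:
        res: dict mapping each site class ('0'-'5') to the 1-based positions
            of residues predicted to belong to it.
        confs: per-residue confidence scores.
        sequence: amino-acid sequence extracted from the structure.
    """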
function = 'Unknown' if function == '' else function
model_path = f'pretrained/{model.lower()}.pth'
if model == 'M3Site-ESM3-abs':
plm_path = ESM3_OPEN_SMALL
blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
elif model == 'M3Site-ESM3-full':
plm_path = ESM3_OPEN_SMALL
blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'
elif model == 'M3Site-ESM2-abs':
plm_path = 'facebook/esm2_t33_650M_UR50D'
blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
elif model == 'M3Site-ESM2-full':
plm_path = 'facebook/esm2_t33_650M_UR50D'
blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'
elif model == 'M3Site-ESM1b-abs':
plm_path = 'facebook/esm1b_t33_650M_UR50S'
blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
elif model == 'M3Site-ESM1b-full':
plm_path = 'facebook/esm1b_t33_650M_UR50S'
        blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'
    else:
        raise ValueError(f'Unknown model name: {model}')
    login(token=os.environ.get("ESM3TOKEN"))  # Hugging Face Hub auth (the open ESM3 weights are gated)
    text_tokenizer = AutoTokenizer.from_pretrained(blm_path)
    text_model = AutoModel.from_pretrained(blm_path)
    model = torch.load(model_path, map_location='cpu')  # load the fine-tuned M3Site model (shadows the model-name string)
    model.eval()
if 'esm3' not in plm_path:
seq_tokenizer = AutoTokenizer.from_pretrained(plm_path)
seq_model = AutoModel.from_pretrained(plm_path)
else:
seq_model = ESM3.from_pretrained(plm_path)
    # Parse the structure from the input PDB file
structure = read_pdb(pdb_file)
    # Per-residue property features, computed on a hydrogen-added copy of the structure
parser = PDBParser(QUIET=True)
    pdb_file_addH = os.path.splitext(pdb_file)[0] + '_addH.pdb'
AddHydrogen(pdb_file, pdb_file_addH)
struct = parser.get_structure('protein', pdb_file_addH)
ss8 = get_ss8_dim9(struct, pdb_file_addH)
angles_matrix = get_dihedrals_dim16(struct, pdb_file_addH)
atom_feature = get_atom_features_dim7(struct)
hbond_feature = get_hbond_features_dim2(pdb_file_addH)
pef_feature = get_pef_features_dim1(struct)
residue_feature = get_residue_features_dim27(struct)
    prop = np.concatenate((ss8, angles_matrix, atom_feature, hbond_feature, pef_feature, residue_feature), axis=1)  # (L, 62) = 9 + 16 + 7 + 2 + 1 + 27 features per residue
os.remove(pdb_file_addH)
    # Extract 3D coordinates of the alpha carbons
alpha_carbons = [atom for atom in structure.get_atoms() if atom.get_id() == 'CA']
positions = [atom.coord for atom in alpha_carbons]
atom_indices = list(range(len(alpha_carbons)))
    # Per-residue node features from the protein language model
sequence = extract_sequence(structure)
assert len(sequence) == len(alpha_carbons)
if 'esm3' not in plm_path:
node_features = get_features(sequence, seq_tokenizer, seq_model)
else:
protein = ESMProtein(sequence=sequence)
protein_tensor = seq_model.encode(protein)
output = seq_model.forward_and_sample(protein_tensor, SamplingConfig(return_per_residue_embeddings=True))
        node_features = output.per_residue_embedding[1:-1]  # drop the special start/end token embeddings
    # Build edges between residue pairs (i < j) with Cα-Cα distance below 8 Å
edges, edge_attrs = [], []
for i, atom1 in enumerate(alpha_carbons):
for j, atom2 in enumerate(alpha_carbons):
if i < j:
distance = np.linalg.norm(atom1.coord - atom2.coord)
if distance < 8.0:
edges.append((i, j))
edge_attrs.append(distance)
edge_index = torch.tensor(np.array(edges), dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(np.array(edge_attrs), dtype=torch.float)
    # Encode the free-text function description with the biomedical language model
func = get_features(function, text_tokenizer, text_model, modal='text')
    # Assemble the torch_geometric Data object
data = Data(x=torch.tensor(np.array(atom_indices), dtype=torch.long).unsqueeze(1),
edge_index=edge_index,
edge_attr=edge_attr.unsqueeze(1),
esm_rep=node_features,
prop=torch.tensor(prop, dtype=torch.float),
pos=torch.tensor(np.array(positions), dtype=torch.float),
func=func)
    with torch.no_grad():
        model_output = model(data)
    output = model_output.argmax(dim=-1).numpy()
    confs = torch.max(model_output, dim=-1)[0].numpy()  # per-residue confidence: maximum value of the model output
    res = {'0': [], '1': [], '2': [], '3': [], '4': [], '5': []}  # residue positions per predicted site class
    for i in range(len(output)):
        if output[i] != 0:  # class 0 = no predicted site
            res[str(output[i] - 1)].append(i + 1)  # positions are reported 1-based
return res, confs, sequence
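

# Minimal usage sketch (assumption: the PDB path and function text below are
# placeholders for illustration, not files shipped with this repository).
# The ESM3TOKEN environment variable must hold a valid Hugging Face token.
if __name__ == '__main__':
    sites, confidences, seq = model_predict(
        model='M3Site-ESM3-abs',
        pdb_file='example.pdb',                       # hypothetical input structure
        function='Catalyzes the hydrolysis of ATP.',  # hypothetical functional description
    )
    for site_class, positions in sites.items():
        print(f'class {site_class}: residues {positions}')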