import sys
import os
# Make the local esm3 package importable before the `esm` imports below
sys.path.append(f'{os.getcwd()}/esm3')
import torch
from utils.util_functions import read_pdb, extract_sequence, get_features
from Bio.PDB import PDBParser
from hadder import AddHydrogen
import numpy as np
from transformers import AutoTokenizer, AutoModel
from utils.get_property_embs import (
    get_ss8_dim9,
    get_dihedrals_dim16,
    get_atom_features_dim7,
    get_hbond_features_dim2,
    get_pef_features_dim1,
    get_residue_features_dim27,
)
from torch_geometric.data import Data
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.sdk.api import ESMProtein, SamplingConfig

def model_predict(model, pdb_file, function):
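    """Predict functional site residues for a protein structure.

    Args:
        model: name of the pretrained variant, e.g. 'M3Site-ESM3-abs'.
        pdb_file: path to the input PDB file.
        function: free-text functional description ('' falls back to 'Unknown').

    Returns:
        res: dict mapping class keys '0'-'5' to lists of 1-based residue positions.
        confs: per-residue maximum class scores from the model output.
        sequence: amino-acid sequence extracted from the structure.
    """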
    function = 'Unknown' if function == '' else function

    model_path = f'pretrained/{model.lower()}.pth'
    if model == 'M3Site-ESM3-abs':
        plm_path = ESM3_OPEN_SMALL
        blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
    elif model == 'M3Site-ESM3-full':
        plm_path = ESM3_OPEN_SMALL
        blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'
    elif model == 'M3Site-ESM2-abs':
        plm_path = 'facebook/esm2_t33_650M_UR50D'
        blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
    elif model == 'M3Site-ESM2-full':
        plm_path = 'facebook/esm2_t33_650M_UR50D'
        blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'
    elif model == 'M3Site-ESM1b-abs':
        plm_path = 'facebook/esm1b_t33_650M_UR50S'
        blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'
    elif model == 'M3Site-ESM1b-full':
        plm_path = 'facebook/esm1b_t33_650M_UR50S'
        blm_path = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'
    else:
        raise ValueError(f'Unknown model name: {model}')

    # Authenticate with the Hugging Face Hub (token read from the ESM3TOKEN environment variable)
    login(token=os.environ.get("ESM3TOKEN"))

    text_tokenizer = AutoTokenizer.from_pretrained(blm_path)
    text_model = AutoModel.from_pretrained(blm_path)

    # Load the full M3Site checkpoint (this rebinds `model` from the name string to the network)
    model = torch.load(model_path, map_location='cpu')
    model.eval()
    if 'esm3' not in plm_path:
        seq_tokenizer = AutoTokenizer.from_pretrained(plm_path)
        seq_model = AutoModel.from_pretrained(plm_path)
    else:
        seq_model = ESM3.from_pretrained(plm_path)

    # Parse the structure from the PDB file
    structure = read_pdb(pdb_file)

    # Compute residue-level property features on a hydrogen-added copy of the structure
    parser = PDBParser(QUIET=True)
    pdb_file_addH = os.path.splitext(pdb_file)[0] + '_addH.pdb'
    AddHydrogen(pdb_file, pdb_file_addH)
    struct = parser.get_structure('protein', pdb_file_addH)
    ss8 = get_ss8_dim9(struct, pdb_file_addH)
    angles_matrix = get_dihedrals_dim16(struct, pdb_file_addH)
    atom_feature = get_atom_features_dim7(struct)
    hbond_feature = get_hbond_features_dim2(pdb_file_addH)
    pef_feature = get_pef_features_dim1(struct)
    residue_feature = get_residue_features_dim27(struct)
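    # Concatenate the per-residue properties (9 + 16 + 7 + 2 + 1 + 27 = 62 columns,
    # as the helper names suggest)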
    prop = np.concatenate((ss8, angles_matrix, atom_feature, hbond_feature, pef_feature, residue_feature), axis=1)
    os.remove(pdb_file_addH)

    # Extract the 3D coordinates of the alpha carbons
    alpha_carbons = [atom for atom in structure.get_atoms() if atom.get_id() == 'CA']
    positions = [atom.coord for atom in alpha_carbons]
    atom_indices = list(range(len(alpha_carbons)))

    # Compute per-residue node features from the sequence
    sequence = extract_sequence(structure)
    assert len(sequence) == len(alpha_carbons)
    if 'esm3' not in plm_path:
        node_features = get_features(sequence, seq_tokenizer, seq_model)
    else:
        protein = ESMProtein(sequence=sequence)
        protein_tensor = seq_model.encode(protein)
        output = seq_model.forward_and_sample(protein_tensor, SamplingConfig(return_per_residue_embeddings=True))
        node_features = output.per_residue_embedding[1:-1]  # drop BOS/EOS token embeddings

    # Build edges between alpha carbons closer than 8 Å, weighted by distance
    edges, edge_attrs = [], []
    for i, atom1 in enumerate(alpha_carbons):
        for j, atom2 in enumerate(alpha_carbons):
            if i < j:
                distance = np.linalg.norm(atom1.coord - atom2.coord)
                if distance < 8.0:
                    edges.append((i, j))
                    edge_attrs.append(distance)
    edge_index = torch.tensor(np.array(edges), dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(np.array(edge_attrs), dtype=torch.float)

    # Embed the functional description text
    func = get_features(function, text_tokenizer, text_model, modal='text')

    # Assemble the PyTorch Geometric Data object
    data = Data(x=torch.tensor(np.array(atom_indices), dtype=torch.long).unsqueeze(1),
                edge_index=edge_index,
                edge_attr=edge_attr.unsqueeze(1),
                esm_rep=node_features,
                prop=torch.tensor(prop, dtype=torch.float),
                pos=torch.tensor(np.array(positions), dtype=torch.float),
                func=func)

    # Run inference without tracking gradients
    with torch.no_grad():
        model_output = model(data)
    output = model_output.argmax(dim=-1).numpy()
    confs = torch.max(model_output, dim=-1)[0].numpy()

    # Collect predicted site residues per class; positions are reported 1-based
    res = {'0': [], '1': [], '2': [], '3': [], '4': [], '5': []}
    for i, label in enumerate(output):
        if label != 0:
            res[str(label - 1)].append(i + 1)

    return res, confs, sequence
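

# Minimal usage sketch: the model name, PDB path, and function text below are
# illustrative placeholders, and ESM3TOKEN is expected to hold a valid
# Hugging Face token before running.
if __name__ == '__main__':
    predictions, confidences, seq = model_predict(
        model='M3Site-ESM2-abs',               # any of the names handled above
        pdb_file='example.pdb',                # hypothetical input structure
        function='Catalyzes ATP hydrolysis.',  # free-text functional description
    )
    for cls, residues in predictions.items():
        print(f'class {cls}: residues {residues}')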