File size: 5,597 Bytes
224a33f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import BatchNorm
from utils.util_classes import CenterLoss, InterClassLoss
from model.egnn.network import EGNN


class AP_align_fuse_graph(torch.nn.Module):
    """Fuse per-residue language-model (ESM) embeddings with EGNN structure
    features via two function-conditioned cross-attention streams, then
    classify every residue into one of ``num_classes`` (7) classes.
    """

    def __init__(self, config, hidden_size=256):
        super(AP_align_fuse_graph, self).__init__()
        self.config = config
        self.seq_max_length = config.dataset.seq_max_length
        # Embedding width depends on which language model produced data.esm_rep.
        if '3' in config.dataset.lm:
            self.embedding_dim = 1536
        elif 't5' in config.dataset.lm:
            self.embedding_dim = 1024
        else:
            self.embedding_dim = 1280

        self.egnn_model = EGNN(config)
        self.egnn_out_dim = self.config.egnn.output_dim
        self.num_classes = 7
        self.fc1 = nn.Linear(self.embedding_dim + self.egnn_out_dim, hidden_size)
        self.bn1 = BatchNorm(hidden_size)
        self.fc2 = nn.Linear(hidden_size, self.num_classes)
        # Structure attends to sequence and vice versa; both streams are
        # conditioned on a 768-d function embedding.
        self.funicross1 = FunICross(self.egnn_out_dim, self.embedding_dim, condition_dim=768)
        self.funicross2 = FunICross(self.embedding_dim, self.egnn_out_dim, condition_dim=768)
        self.weight_fc = nn.Linear((self.embedding_dim + self.egnn_out_dim) * 2, 1)
        # NOTE(review): the two losses and the ab_* projections below are not
        # used in forward(); presumably kept for training scripts or ablation
        # variants — confirm before removing.
        self.center_loss = CenterLoss(num_classes=self.num_classes, feat_dim=7)
        self.inter_loss = InterClassLoss(margin=0.1)
        self.ab_egnn = nn.Linear(self.egnn_out_dim, self.embedding_dim + self.egnn_out_dim)
        self.ab_esm = nn.Linear(self.embedding_dim, self.embedding_dim + self.egnn_out_dim)

    def forward(self, data):
        """Classify each residue of a single graph.

        Parameters
        ----------
        data : graph batch object providing ``esm_rep`` ([nodes, embedding_dim]
            LM embeddings), ``func`` (768-d function embedding) and whatever
            fields the EGNN consumes.

        Returns
        -------
        Tensor of shape [nodes, num_classes] with per-residue softmax
        probabilities (padding positions already stripped).
        """
        esm_rep, func = data.esm_rep, data.func
        graphs = 1  # forward handles exactly one graph per call
        egnn_output = self.egnn_model(data)  # [nodes, egnn_out_dim]

        func_data = func.reshape(graphs, 768)  # [1, 768]

        # Zero-pad both modalities to the configured max length (previously a
        # hard-coded 1024 that shadowed config.dataset.seq_max_length).
        # NOTE(review): padded positions still participate in the attention
        # softmax, so the pad length affects the output — confirm
        # seq_max_length == 1024 to reproduce earlier results exactly.
        pad_len = self.seq_max_length
        esm_data = F.pad(esm_rep, (0, 0, 0, pad_len - esm_rep.shape[0]), value=0).unsqueeze(0)
        egnn_data = F.pad(egnn_output, (0, 0, 0, pad_len - egnn_output.shape[0]), value=0).unsqueeze(0)

        total = torch.cat([esm_data, egnn_data], dim=-1)  # [1, pad_len, emb+egnn]
        stru_seq_seq = self.funicross1(egnn_data, esm_data, esm_data, func_data)   # [1, pad_len, egnn_out_dim]
        seq_stru_stru = self.funicross2(esm_data, egnn_data, egnn_data, func_data)  # [1, pad_len, embedding_dim]
        fusion_out = torch.cat([stru_seq_seq, seq_stru_stru], dim=-1)  # [1, pad_len, emb+egnn]

        # Learned per-position gate between the fused features and the raw
        # concatenation of the two modalities.
        combined = torch.cat([fusion_out, total], dim=-1)
        weight = torch.sigmoid(self.weight_fc(combined))
        out = weight * fusion_out + (1 - weight) * total

        # BatchNorm normalizes over dim 1, hence the permute round-trip.
        out = self.fc1(out).permute(0, 2, 1)
        out = self.bn1(out).permute(0, 2, 1)
        out = torch.relu(out)
        out = self.fc2(out)

        # Drop the padding rows and convert logits to probabilities.
        recon_out = out[0][:esm_rep.shape[0]]
        recon_out = torch.softmax(recon_out, dim=-1)

        return recon_out


class CrossAttention(nn.Module):
    """Scaled dot-product cross-attention.

    The query stream is used as-is (no query projection); keys and values
    are linearly mapped from ``dim2`` into the query space ``dim1``.  The
    ``dropout`` argument is accepted for signature compatibility but no
    dropout is applied.
    """

    def __init__(self, dim1, dim2, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.dim1 = dim1
        self.key = nn.Linear(dim2, dim1)
        self.value = nn.Linear(dim2, dim1)
        self.out = nn.Linear(dim1, dim1)

    def forward(self, Q, K, V):
        # Map the other modality's keys/values into Q's space.
        keys = self.key(K)        # [..., len_k, dim1]
        values = self.value(V)    # [..., len_k, dim1]
        # Raw dot-product scores; the blocks line up so a direct matmul works.
        scores = torch.matmul(Q, keys.transpose(-2, -1))
        scores = scores / (self.dim1 ** 0.5)  # scale by sqrt(dim1)
        probs = F.softmax(scores, dim=-1)     # normalize over the key axis
        attended = torch.matmul(probs, values)  # [..., len_q, dim1]
        return self.out(attended)


class FeedForward(nn.Module):
    """Two-layer position-wise MLP.

    When ``condition_dim`` is given, a per-graph condition vector is
    broadcast along the sequence axis and concatenated onto the input
    before the first projection.
    """

    def __init__(self, dim, ff_dim=128, dropout=0.1, condition_dim=None):
        super(FeedForward, self).__init__()
        input_dim = dim if condition_dim is None else dim + condition_dim
        self.fc1 = nn.Linear(input_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, dim)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x, condition=None):
        if condition is not None:
            # Tile the condition across every sequence position of x.
            tiled = condition.unsqueeze(1).expand(-1, x.size(1), -1)
            x = torch.cat([x, tiled], dim=-1)
        hidden = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(hidden)


class FunICross(nn.Module):
    """Transformer-style fusion layer: cross-attention followed by a
    condition-aware feed-forward block, each wrapped in a residual
    connection plus LayerNorm (post-norm ordering).
    """

    def __init__(self, dim1, dim2, ff_dim=128, dropout=0.1, condition_dim=None):
        super(FunICross, self).__init__()
        self.attn = CrossAttention(dim1, dim2, dropout)
        self.attn_layer_norm = nn.LayerNorm(dim1)
        self.ff = FeedForward(dim1, ff_dim, dropout, condition_dim)
        self.ff_layer_norm = nn.LayerNorm(dim1)

    def forward(self, Q, K, V, condition=None):
        # Residual + norm around the cross-attention sub-layer.
        attended = self.attn(Q, K, V)
        normed = self.attn_layer_norm(Q + attended)
        # The condition vector is injected into the feed-forward input.
        ff_out = self.ff(normed, condition)
        return self.ff_layer_norm(normed + ff_out)