import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import BatchNorm

from utils.util_classes import CenterLoss, InterClassLoss
from model.egnn.network import EGNN


class AP_align_fuse_graph(torch.nn.Module):
    def __init__(self, config, hidden_size=256):
        super(AP_align_fuse_graph, self).__init__()
        self.config = config
        self.seq_max_length = config.dataset.seq_max_length

        # Embedding dimension of the protein language model backbone.
        if '3' in config.dataset.lm:
            self.embedding_dim = 1536
        elif 't5' in config.dataset.lm:
            self.embedding_dim = 1024
        else:
            self.embedding_dim = 1280

        self.egnn_model = EGNN(config)
        self.egnn_out_dim = self.config.egnn.output_dim
        self.num_classes = 7

        # Per-residue classification head over the fused sequence + structure features.
        self.fc1 = nn.Linear(self.embedding_dim + self.egnn_out_dim, hidden_size)
        self.bn1 = BatchNorm(hidden_size)
        self.fc2 = nn.Linear(hidden_size, self.num_classes)

        # Cross-attention blocks: structure attends to sequence and vice versa,
        # both conditioned on a 768-d function embedding.
        self.funicross1 = FunICross(self.egnn_out_dim, self.embedding_dim, condition_dim=768)
        self.funicross2 = FunICross(self.embedding_dim, self.egnn_out_dim, condition_dim=768)

        # Gate that blends the cross-attended fusion with the plain concatenation.
        self.weight_fc = nn.Linear((self.embedding_dim + self.egnn_out_dim) * 2, 1)

        self.center_loss = CenterLoss(num_classes=self.num_classes, feat_dim=7)
        self.inter_loss = InterClassLoss(margin=0.1)

        # Single-modality projections (presumably for ablations; not used in this forward pass).
        self.ab_egnn = nn.Linear(self.egnn_out_dim, self.embedding_dim + self.egnn_out_dim)
        self.ab_esm = nn.Linear(self.embedding_dim, self.embedding_dim + self.egnn_out_dim)

    def forward(self, data):
        esm_rep, batch, func = data.esm_rep, data.batch, data.func
        graphs = 1  # one graph per forward pass
        egnn_output = self.egnn_model(data)  # [num_nodes, egnn_out_dim]

        # esm_data = torch.zeros(graphs, 1024, 1280).to(esm_rep.device)  # [1, 1024, 1280]
        # egnn_data = torch.zeros(graphs, 1024, self.egnn_out_dim).to(egnn_output.device)  # [1, 1024, 16]
        func_data = func.reshape(graphs, 768)  # [1, 768]
        # for graph_idx in range(graphs):
        #     mask = (batch == graph_idx)
        #     esm_data[graph_idx][:esm_rep[mask].shape[0]] = esm_rep[mask]
        #     egnn_data[graph_idx][:egnn_output[mask].shape[0]] = egnn_output[mask]

        # Zero-pad both modalities to a fixed length of 1024 residues and add a batch dim.
        esm_data = F.pad(esm_rep, (0, 0, 0, 1024 - esm_rep.shape[0]), value=0).unsqueeze(0)
        egnn_data = F.pad(egnn_output, (0, 0, 0, 1024 - egnn_output.shape[0]), value=0).unsqueeze(0)
        total = torch.cat([esm_data, egnn_data], dim=-1)  # [graphs, 1024, embedding_dim + egnn_out_dim]

        # Bidirectional cross-attention conditioned on the function embedding.
        stru_seq_seq = self.funicross1(egnn_data, esm_data, esm_data, func_data)  # [graphs, 1024, egnn_out_dim]
        seq_stru_stru = self.funicross2(esm_data, egnn_data, egnn_data, func_data)  # [graphs, 1024, embedding_dim]
        fusion_out = torch.cat([stru_seq_seq, seq_stru_stru], dim=-1)  # [graphs, 1024, embedding_dim + egnn_out_dim]

        # Learned gate between the fused features and the raw concatenation.
        combined = torch.cat([fusion_out, total], dim=-1)
        weight = torch.sigmoid(self.weight_fc(combined))
        out = weight * fusion_out + (1 - weight) * total

        # Classification head; BatchNorm expects channels in dim 1, hence the permutes.
        out = self.fc1(out).permute(0, 2, 1)
        out = self.bn1(out).permute(0, 2, 1)
        out = torch.relu(out)
        out = self.fc2(out)

        # Drop the padded positions and return per-residue class probabilities.
        recon_out = out[0][:esm_rep.shape[0]]
        recon_out = torch.softmax(recon_out, dim=-1)
        return recon_out


class CrossAttention(nn.Module):
    def __init__(self, dim1, dim2, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.dim1 = dim1
        # Queries keep their native dimension (dim1); keys and values are projected
        # from dim2 into dim1. The dropout argument is currently unused here.
        self.key = nn.Linear(dim2, dim1)
        self.value = nn.Linear(dim2, dim1)
        self.out = nn.Linear(dim1, dim1)

    def forward(self, Q, K, V):
        Q_proj = Q
        K_proj = self.key(K)    # [graphs, len, dim1]
        V_proj = self.value(V)  # [graphs, len, dim1]
        # Since these are block matrices (the padded rows are zero), the matmul can be applied directly.
        attention_scores = torch.matmul(Q_proj, K_proj.transpose(-2, -1))  # [graphs, len, len]
        attention_scores = attention_scores / (self.dim1 ** 0.5)  # scale by the square root of dim1
        attention_probs = F.softmax(attention_scores, dim=-1)  # softmax over the last dimension (keys)
        context = torch.matmul(attention_probs, V_proj)  # [graphs, len, dim1]
        output = self.out(context)  # [graphs, len, dim1]
        return output
class FeedForward(nn.Module):
    def __init__(self, dim, ff_dim=128, dropout=0.1, condition_dim=None):
        super(FeedForward, self).__init__()
        # If a condition is provided, it is concatenated to the input features.
        input_dim = dim + condition_dim if condition_dim is not None else dim
        self.fc1 = nn.Linear(input_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, dim)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x, condition=None):
        if condition is not None:
            # Broadcast the per-graph condition to every sequence position.
            condition = condition.unsqueeze(1).expand(-1, x.size(1), -1)  # [graphs, len, condition_dim]
            x = torch.cat([x, condition], dim=-1)  # [graphs, len, dim + condition_dim]
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class FunICross(nn.Module):
    def __init__(self, dim1, dim2, ff_dim=128, dropout=0.1, condition_dim=None):
        super(FunICross, self).__init__()
        self.attn = CrossAttention(dim1, dim2, dropout)
        self.attn_layer_norm = nn.LayerNorm(dim1)
        self.ff = FeedForward(dim1, ff_dim, dropout, condition_dim)
        self.ff_layer_norm = nn.LayerNorm(dim1)

    def forward(self, Q, K, V, condition=None):
        # Post-norm transformer-style block: cross-attention followed by a conditioned
        # feed-forward, each with a residual connection and LayerNorm.
        attn_output = self.attn(Q, K, V)
        Q = self.attn_layer_norm(Q + attn_output)
        ff_output = self.ff(Q, condition)  # the condition is concatenated to the feed-forward input
        Q = self.ff_layer_norm(Q + ff_output)
        return Q
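

# --- Minimal shape sanity check (illustrative sketch, not part of the original module) ---
# This assumes dummy dimensions (egnn_out_dim=16, embedding_dim=1280, condition_dim=768,
# sequence length 1024) purely to exercise FunICross in isolation; the real model derives
# these from its config and feeds features produced by the EGNN and language-model backbones.
if __name__ == "__main__":
    egnn_out_dim, embedding_dim, condition_dim, seq_len = 16, 1280, 768, 1024
    egnn_data = torch.randn(1, seq_len, egnn_out_dim)  # structure features (query)
    esm_data = torch.randn(1, seq_len, embedding_dim)  # sequence features (key/value)
    func_data = torch.randn(1, condition_dim)          # per-graph function embedding

    block = FunICross(egnn_out_dim, embedding_dim, condition_dim=condition_dim)
    fused = block(egnn_data, esm_data, esm_data, func_data)
    print(fused.shape)  # expected: torch.Size([1, 1024, 16])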