# M3Site / model / model.py
# Hosting-page header captured with the file (not part of the program):
#   uploader: anonymousforpaper
#   commit:   "Upload 103 files" (224a33f, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import BatchNorm
from utils.util_classes import CenterLoss, InterClassLoss
from model.egnn.network import EGNN
class AP_align_fuse_graph(torch.nn.Module):
    """Per-row classifier that fuses sequence embeddings with EGNN features.

    The two feature streams are zero-padded to 1024 rows, cross-attended in
    both directions by two ``FunICross`` blocks (each conditioned on a
    768-value function vector), blended with the raw concatenation through a
    learned sigmoid gate, and classified into ``num_classes`` (7) classes.

    ``center_loss``, ``inter_loss``, ``ab_egnn`` and ``ab_esm`` are built here
    but not used by ``forward``; presumably other code uses them -- verify.
    """

    def __init__(self, config, hidden_size=256):
        """Build all sub-modules.

        Args:
            config: Configuration object; this constructor reads
                ``config.dataset.seq_max_length``, ``config.dataset.lm`` and
                ``config.egnn.output_dim``, and hands ``config`` to ``EGNN``.
            hidden_size: Width of the hidden layer of the classification head.
        """
        super().__init__()
        self.config = config
        # Stored but not consulted by forward(), which pads to a fixed 1024.
        self.seq_max_length = config.dataset.seq_max_length

        # Embedding width is picked by substring match on the configured
        # language-model name; the first matching marker wins.
        lm_name = config.dataset.lm
        for marker, width in (('3', 1536), ('t5', 1024)):
            if marker in lm_name:
                self.embedding_dim = width
                break
        else:
            self.embedding_dim = 1280

        # NOTE: sub-modules are created in the same order as before so that
        # seeded initialisation and state_dict key order stay unchanged.
        self.egnn_model = EGNN(config)
        self.egnn_out_dim = self.config.egnn.output_dim
        self.num_classes = 7
        fused_dim = self.embedding_dim + self.egnn_out_dim

        # Classification head.
        self.fc1 = nn.Linear(fused_dim, hidden_size)
        self.bn1 = BatchNorm(hidden_size)
        self.fc2 = nn.Linear(hidden_size, self.num_classes)

        # funicross1: graph features query the sequence features.
        # funicross2: sequence features query the graph features.
        self.funicross1 = FunICross(self.egnn_out_dim, self.embedding_dim, condition_dim=768)
        self.funicross2 = FunICross(self.embedding_dim, self.egnn_out_dim, condition_dim=768)

        # Produces one gate value per row from [fused, raw] concatenated.
        self.weight_fc = nn.Linear(fused_dim * 2, 1)

        self.center_loss = CenterLoss(num_classes=self.num_classes, feat_dim=7)
        self.inter_loss = InterClassLoss(margin=0.1)
        self.ab_egnn = nn.Linear(self.egnn_out_dim, fused_dim)
        self.ab_esm = nn.Linear(self.embedding_dim, fused_dim)

    def forward(self, data):
        """Return class probabilities for every row of ``data.esm_rep``.

        Args:
            data: Object exposing ``esm_rep``, ``batch`` and ``func`` and
                accepted by the EGNN sub-module. ``func`` must hold exactly
                768 values (it is reshaped to ``[1, 768]``), so only a single
                graph per call is handled.
                Assumes ``esm_rep`` is 2-D ``[rows, embedding_dim]`` and the
                EGNN output is 2-D ``[rows, egnn_out_dim]`` -- TODO confirm.

        Returns:
            Tensor with ``num_classes`` columns holding softmax probabilities,
            one row for each of the first ``esm_rep.shape[0]`` padded rows.
        """
        seq_feats = data.esm_rep
        _batch = data.batch  # read as in the original unpacking; unused below
        func_vec = data.func

        n_graphs = 1   # only one graph per call is supported
        pad_len = 1024  # fixed padded length (not taken from seq_max_length)

        graph_feats = self.egnn_model(data)
        condition = func_vec.reshape(n_graphs, 768)

        # Zero-pad the row dimension up to pad_len and add a leading axis.
        # NOTE(review): with more than 1024 rows the pad amount goes negative;
        # confirm upstream caps the length.
        seq_pad_rows = pad_len - seq_feats.shape[0]
        graph_pad_rows = pad_len - graph_feats.shape[0]
        seq_padded = F.pad(seq_feats, (0, 0, 0, seq_pad_rows), value=0).unsqueeze(0)
        graph_padded = F.pad(graph_feats, (0, 0, 0, graph_pad_rows), value=0).unsqueeze(0)

        # Raw concatenation: sequence channels first, then graph channels.
        raw = torch.cat([seq_padded, graph_padded], dim=-1)

        # Bidirectional cross-attention, both conditioned on the function vector.
        # NOTE(review): the padded rows are passed through attention unmasked
        # (CrossAttention takes no mask) -- confirm this is intended.
        graph_attended = self.funicross1(graph_padded, seq_padded, seq_padded, condition)
        seq_attended = self.funicross2(seq_padded, graph_padded, graph_padded, condition)

        # NOTE(review): this is graph channels first, then sequence channels,
        # i.e. the opposite channel order to `raw`, yet the two are blended
        # element-wise below. Left unchanged to keep trained weights valid.
        fused = torch.cat([graph_attended, seq_attended], dim=-1)

        # Per-row scalar gate choosing between attended and raw features.
        gate = torch.sigmoid(self.weight_fc(torch.cat([fused, raw], dim=-1)))
        mixed = gate * fused + (1 - gate) * raw

        # The norm layer is applied with the feature axis in position 1, so
        # swap axes around it and swap back afterwards.
        hidden = self.bn1(self.fc1(mixed).transpose(1, 2)).transpose(1, 2)
        logits = self.fc2(torch.relu(hidden))

        # Drop the padded rows of the single graph and normalise over classes.
        return torch.softmax(logits[0][:seq_feats.shape[0]], dim=-1)
class CrossAttention(nn.Module):
    """Single-head scaled dot-product cross-attention with an unprojected query.

    Keys and values are linearly mapped from ``dim2`` to ``dim1``; the query is
    used as given, so its last dimension must already equal ``dim1``.
    """

    def __init__(self, dim1, dim2, dropout=0.1):
        """Create the key, value and output projections.

        Args:
            dim1: Feature width of the query and of the output.
            dim2: Feature width of the key and value inputs.
            dropout: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.dim1 = dim1
        self.key = nn.Linear(dim2, dim1)
        self.value = nn.Linear(dim2, dim1)
        self.out = nn.Linear(dim1, dim1)

    def forward(self, Q, K, V):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: Query tensor with last dimension ``dim1``.
            K: Key tensor with last dimension ``dim2``.
            V: Value tensor with last dimension ``dim2``.

        Returns:
            Tensor shaped like ``Q`` (leading dims of ``Q``, last dim ``dim1``).
        """
        keys = self.key(K)
        values = self.value(V)

        # The query needs no projection: it already lives in the dim1 space the
        # keys were mapped into, so the two can be multiplied directly.
        scores = torch.matmul(Q, keys.transpose(-2, -1))
        scores = scores / (self.dim1 ** 0.5)  # scale by sqrt(dim1)

        # Normalise over the key axis, then mix the values.
        weights = scores.softmax(dim=-1)
        attended = torch.matmul(weights, values)
        return self.out(attended)
class FeedForward(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Dropout -> Linear), optionally conditioned.

    When ``condition_dim`` is given, the first layer expects the condition
    vector to be concatenated onto every position of the input.
    """

    def __init__(self, dim, ff_dim=128, dropout=0.1, condition_dim=None):
        """Create the layers.

        Args:
            dim: Feature width of the input and of the output.
            ff_dim: Width of the hidden layer.
            dropout: Dropout probability applied after the activation.
            condition_dim: Width of the optional condition vector; ``None``
                means the first layer takes ``dim`` features only.
        """
        super().__init__()
        if condition_dim is None:
            first_in = dim
        else:
            first_in = dim + condition_dim
        self.fc1 = nn.Linear(first_in, ff_dim)
        self.fc2 = nn.Linear(ff_dim, dim)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x, condition=None):
        """Apply the MLP to ``x``.

        Args:
            x: Input with a position axis at index 1 and ``dim`` features last.
            condition: Optional 2-D tensor ``[batch, condition_dim]``; it is
                repeated along ``x``'s axis 1 and concatenated to the features.

        Returns:
            Tensor with the same leading dimensions as ``x`` and ``dim``
            features in the last dimension.
        """
        if condition is not None:
            # Broadcast the per-sample condition to every position of x.
            per_position = condition.unsqueeze(1).expand(-1, x.size(1), -1)
            x = torch.cat([x, per_position], dim=-1)
        hidden = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(hidden)
class FunICross(nn.Module):
    """Cross-attention block followed by a condition-aware feed-forward block.

    Both sub-layers are wrapped in a residual connection and a post-layer
    ``LayerNorm`` over the ``dim1`` features of the query stream.
    """

    def __init__(self, dim1, dim2, ff_dim=128, dropout=0.1, condition_dim=None):
        """Create the attention, feed-forward and normalisation layers.

        Args:
            dim1: Feature width of the query stream (and of the output).
            dim2: Feature width of the key/value stream.
            ff_dim: Hidden width of the feed-forward block.
            dropout: Passed on to both sub-layers.
            condition_dim: Width of the optional condition vector that is fed
                to the feed-forward block.
        """
        super().__init__()
        self.attn = CrossAttention(dim1, dim2, dropout)
        self.attn_layer_norm = nn.LayerNorm(dim1)
        self.ff = FeedForward(dim1, ff_dim, dropout, condition_dim)
        self.ff_layer_norm = nn.LayerNorm(dim1)

    def forward(self, Q, K, V, condition=None):
        """Update ``Q`` with information attended from ``K``/``V``.

        Args:
            Q: Query tensor with last dimension ``dim1``.
            K: Key tensor with last dimension ``dim2``.
            V: Value tensor with last dimension ``dim2``.
            condition: Optional condition tensor forwarded to the
                feed-forward block.

        Returns:
            Tensor with the same shape as ``Q``.
        """
        # Residual + norm around the cross-attention.
        attended = self.attn_layer_norm(Q + self.attn(Q, K, V))
        # The condition enters only here, as extra feed-forward input.
        refined = self.ff(attended, condition)
        return self.ff_layer_norm(attended + refined)