import math from math import pi as PI import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.parallel import torch.utils.data import torch_geometric.transforms as T from torch.nn import ModuleList, Parameter from torch_geometric.nn import HANConv, HEATConv, HGTConv, Linear from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear # from dataset import from torch_geometric.nn.inits import glorot, zeros from torch_geometric.utils import softmax from torch_scatter import scatter from util import get_angle, get_theta, triplets class Smodel(nn.Module): def __init__(self, h_channel=16,input_featuresize=32,localdepth=2,num_interactions=3,finaldepth=3,share='0',batchnorm="True"): super(Smodel,self).__init__() self.training=True self.h_channel = h_channel self.input_featuresize=input_featuresize self.localdepth = localdepth self.num_interactions=num_interactions self.finaldepth=finaldepth self.batchnorm = batchnorm self.activation=nn.ReLU() self.att = Parameter(torch.ones(4),requires_grad=True) num_gaussians=(1,1,1) self.mlp_geo = ModuleList() for i in range(self.localdepth): if i == 0: self.mlp_geo.append(Linear(sum(num_gaussians), h_channel)) else: self.mlp_geo.append(Linear(h_channel, h_channel)) if self.batchnorm == "True": self.mlp_geo.append(nn.BatchNorm1d(h_channel)) self.mlp_geo.append(self.activation) self.mlp_geo_backup = ModuleList() for i in range(self.localdepth): if i == 0: self.mlp_geo_backup.append(Linear(4, h_channel)) else: self.mlp_geo_backup.append(Linear(h_channel, h_channel)) if self.batchnorm == "True": self.mlp_geo_backup.append(nn.BatchNorm1d(h_channel)) self.mlp_geo_backup.append(self.activation) self.translinear=Linear(input_featuresize+1, self.h_channel) self.interactions= ModuleList() for i in range(self.num_interactions): block = SPNN( in_ch=self.input_featuresize, hidden_channels=self.h_channel, activation=self.activation, finaldepth=self.finaldepth, batchnorm=self.batchnorm, num_input_geofeature=self.h_channel ) self.interactions.append(block) self.reset_parameters() def reset_parameters(self): for lin in self.mlp_geo: if isinstance(lin, Linear): torch.nn.init.xavier_uniform_(lin.weight) lin.bias.data.fill_(0) for i in (self.interactions): i.reset_parameters() def single_forward(self, input_feature,coords,edge_index,edge_index_2rd, edx_jk, edx_ij,batch,num_edge_inside,edge_rep): if edge_rep: i, j, k = edge_index_2rd edge_index1,edge_index2= edge_index edge_index_all=torch.cat([edge_index1,edge_index2],1) distance_ij=(coords[j] - coords[i]).norm(p=2, dim=1) distance_jk=(coords[j] - coords[k]).norm(p=2, dim=1) theta_ijk = get_angle(coords[j] - coords[i], coords[k] - coords[j]) geo_encoding_1st=distance_ij[:,None] geo_encoding=torch.cat([geo_encoding_1st,distance_jk[:,None],theta_ijk[:,None]],dim=-1) else: coords_j = coords[edge_index[0]] coords_i = coords[edge_index[1]] geo_encoding=torch.cat([coords_j,coords_i],dim=-1) if edge_rep: for lin in self.mlp_geo: geo_encoding=lin(geo_encoding) else: for lin in self.mlp_geo_backup: geo_encoding=lin(geo_encoding) geo_encoding=torch.zeros_like(geo_encoding,device=geo_encoding.device,dtype=geo_encoding.dtype) node_feature= input_feature node_feature_list=[] for interaction in self.interactions: node_feature = interaction(node_feature,geo_encoding,edge_index_2rd,edx_jk,edx_ij,num_edge_inside,self.att) node_feature_list.append(node_feature) return node_feature_list def forward(self, input_feature, coords,edge_index,edge_index_2rd, edx_jk, edx_ij,batch,num_edge_inside,edge_rep): output=self.single_forward(input_feature,coords,edge_index,edge_index_2rd, edx_jk, edx_ij,batch,num_edge_inside,edge_rep) return output class SPNN(torch.nn.Module): def __init__( self, in_ch, hidden_channels, activation=torch.nn.ReLU(), finaldepth=3, batchnorm="True", num_input_geofeature=13 ): super(SPNN, self).__init__() self.activation = activation self.finaldepth = finaldepth self.batchnorm = batchnorm self.num_input_geofeature=num_input_geofeature self.WMLP_list = ModuleList() for _ in range(4): WMLP = ModuleList() for i in range(self.finaldepth + 1): if i == 0: WMLP.append(Linear(hidden_channels*3+num_input_geofeature, hidden_channels)) else: WMLP.append(Linear(hidden_channels, hidden_channels)) if self.batchnorm == "True": WMLP.append(nn.BatchNorm1d(hidden_channels)) WMLP.append(self.activation) self.WMLP_list.append(WMLP) self.reset_parameters() def reset_parameters(self): for mlp in self.WMLP_list: for lin in mlp: if isinstance(lin, Linear): torch.nn.init.xavier_uniform_(lin.weight) lin.bias.data.fill_(0) def forward(self, node_feature,geo_encoding,edge_index_2rd,edx_jk,edx_ij,num_edge_inside,att): i,j,k = edge_index_2rd if node_feature is None: concatenated_vector = geo_encoding else: node_attr_0st = node_feature[i] node_attr_1st = node_feature[j] node_attr_2 = node_feature[k] concatenated_vector = torch.cat( [ node_attr_0st, node_attr_1st,node_attr_2, geo_encoding, ], dim=-1, ) x_i = concatenated_vector edge1_edge1_mask = (edx_ij < num_edge_inside) & (edx_jk < num_edge_inside) edge1_edge2_mask = (edx_ij < num_edge_inside) & (edx_jk >= num_edge_inside) edge2_edge1_mask = (edx_ij >= num_edge_inside) & (edx_jk < num_edge_inside) edge2_edge2_mask = (edx_ij >= num_edge_inside) & (edx_jk >= num_edge_inside) masks=[edge1_edge1_mask,edge1_edge2_mask,edge2_edge1_mask,edge2_edge2_mask] x_output=torch.zeros(x_i.shape[0],self.WMLP_list[0][0].weight.shape[0],device=x_i.device) for index in range(4): WMLP=self.WMLP_list[index] x=x_i[masks[index]] for lin in WMLP: x=lin(x) x = F.leaky_relu(x)*att[index] x_output[masks[index]]+=x out_feature = scatter(x_output, i, dim=0, reduce='add') return out_feature class HGT(torch.nn.Module): def __init__(self, hidden_channels, out_channels, num_heads, num_layers): super().__init__() self.lin_dict = torch.nn.ModuleDict() for node_type in ["vertices"]: self.lin_dict[node_type] = Linear(-1, hidden_channels) self.convs = torch.nn.ModuleList() for _ in range(num_layers): conv = HGTConv(hidden_channels, hidden_channels, (['vertices'],[('vertices', 'inside', 'vertices'), ('vertices', 'apart', 'vertices')]), num_heads, group='sum') self.convs.append(conv) self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): for node_type, x in x_dict.items(): x_dict[node_type]=self.lin_dict[node_type](x).relu_() for conv in self.convs: x_dict = conv(x_dict, edge_index_dict) return self.lin(x_dict['vertices']) class HAN(torch.nn.Module): def __init__(self, hidden_channels, out_channels, num_heads, num_layers): super().__init__() self.lin_dict = torch.nn.ModuleDict() for node_type in ["vertices"]: self.lin_dict[node_type] = Linear(-1, hidden_channels) self.convs = torch.nn.ModuleList() for _ in range(num_layers): conv = HANConv(hidden_channels, hidden_channels, (['vertices'],[('vertices', 'inside', 'vertices'), ('vertices', 'apart', 'vertices')]), num_heads) self.convs.append(conv) self.lin = Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): for node_type, x in x_dict.items(): x_dict[node_type]=self.lin_dict[node_type](x).relu_() for conv in self.convs: x_dict = conv(x_dict, edge_index_dict) return self.lin(x_dict['vertices'])