import argparse
import copy
import json
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import torch
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import cm
from sklearn.metrics import (auc, explained_variance_score, f1_score,
                             mean_absolute_error, mean_squared_error,
                             precision_score, r2_score, recall_score,
                             roc_auc_score, roc_curve)
from torch.nn.functional import softmax
from torch_geometric.utils import subgraph

# anomaly detection makes autograd errors easier to trace, at a noticeable runtime cost
torch.autograd.set_detect_anomaly(True)

import math
import pickle
from datetime import date, datetime, timedelta

import torch.nn as nn
import torch_geometric
import torchvision.datasets
import torchvision.models
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn import GIN, MLP, GATConv
from torch_geometric.nn.pool import global_add_pool, global_mean_pool
from torch_geometric.utils import add_self_loops

import dataset
import model_new
import util
from dataset import label_mapping, reverse_label_mapping
from model_new import Smodel

# ANSI escape helpers for colored console output
blue = lambda x: '\033[94m' + x + '\033[0m'
red = lambda x: '\033[31m' + x + '\033[0m'
green = lambda x: '\033[32m' + x + '\033[0m'
yellow = lambda x: '\033[33m' + x + '\033[0m'
greenline = lambda x: '\033[42m' + x + '\033[0m'
yellowline = lambda x: '\033[43m' + x + '\033[0m'

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default="our", type=str)
    parser.add_argument('--train_batch', default=64, type=int)
    parser.add_argument('--test_batch', default=128, type=int)
    parser.add_argument('--share', type=str, default="0")
    parser.add_argument('--edge_rep', type=str, default="True")
    parser.add_argument('--batchnorm', type=str, default="True")
    parser.add_argument('--extent_norm', type=str, default="T")
    parser.add_argument('--spanning_tree', type=str, default="T")

    parser.add_argument('--loss_coef', default=0.1, type=float)
    parser.add_argument('--h_ch', default=512, type=int)
    parser.add_argument('--localdepth', type=int, default=1)
    parser.add_argument('--num_interactions', type=int, default=4)
    parser.add_argument('--finaldepth', type=int, default=4)
    parser.add_argument('--classifier_depth', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.0)

    parser.add_argument('--dataset', type=str, default='mnist')
    parser.add_argument('--log', type=str, default="True")
    parser.add_argument('--test_per_round', type=int, default=10)
    parser.add_argument('--patience', type=int, default=30)
    parser.add_argument('--nepoch', type=int, default=301)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--manualSeed', type=str, default="False")
    parser.add_argument('--man_seed', type=int, default=12345)
    args = parser.parse_args()

    # string flags arrive as "True"/"False"; convert them to real booleans
    args.log = args.log == "True"
    args.edge_rep = args.edge_rep == "True"
    args.batchnorm = args.batchnorm == "True"
    args.manualSeed = args.manualSeed == "True"
    args.save_dir = os.path.join('./save/', args.dataset, args.model)
    return args
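
# Example invocation (the script filename is hypothetical; the flags are the ones
# defined above):
#   python train.py --dataset mnist --model our --train_batch 64 --lr 1e-4 --nepoch 301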

args = get_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()

# dataset name -> (number of output classes, path to the preprocessed pickle)
DATASET_CONFIG = {
    'mnist': (90, 'data/multi_mnist_with_index.pkl'),
    'mnist_sparse': (90, 'data/multi_mnist_sparse.pkl'),
    'building': (100, 'data/building_with_index.pkl'),
    'mbuilding': (100, 'data/mp_building.pkl'),
    'sbuilding': (10, 'data/single_building.pkl'),
    'smnist': (10, 'data/single_mnist.pkl'),
    'dbp': (2, 'data/triple_building_600.pkl'),
}
x_out, args.data_dir = DATASET_CONFIG[args.dataset]

if args.model == "our":
    model = Smodel(h_channel=args.h_ch, input_featuresize=args.h_ch,
                   localdepth=args.localdepth, num_interactions=args.num_interactions,
                   finaldepth=args.finaldepth, share=args.share, batchnorm=args.batchnorm)
    # the classifier sees the concatenation of all interaction blocks' outputs
    mlpmodel = MLP(in_channels=args.h_ch * args.num_interactions, hidden_channels=args.h_ch,
                   out_channels=x_out, num_layers=args.classifier_depth, dropout=args.dropout)
elif args.model == "HGT":
    model = model_new.HGT(hidden_channels=args.h_ch, out_channels=args.h_ch,
                          num_heads=2, num_layers=args.num_interactions)
    mlpmodel = MLP(in_channels=args.h_ch, hidden_channels=args.h_ch, out_channels=x_out,
                   num_layers=args.classifier_depth, dropout=args.dropout)
elif args.model == "HAN":
    model = model_new.HAN(hidden_channels=args.h_ch, out_channels=args.h_ch,
                          num_heads=2, num_layers=args.num_interactions)
    mlpmodel = MLP(in_channels=args.h_ch, hidden_channels=args.h_ch, out_channels=x_out,
                   num_layers=args.classifier_depth, dropout=args.dropout)

model.to(device)
mlpmodel.to(device)
opt_list = list(model.parameters()) + list(mlpmodel.parameters())

optimizer = torch.optim.Adam(opt_list, lr=args.lr)
# the scheduler is stepped on validation accuracy in main(), so it must maximize
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                                                       patience=args.patience, min_lr=1e-8)


def contrastive_loss(embeddings, labels, margin):
    # positive pairs share a label; negative pairs do not
    positive_mask = labels.view(-1, 1) == labels.view(1, -1)
    negative_mask = ~positive_mask

    # the diagonal (self-pairs) is excluded from the positive count
    num_positive_pairs = positive_mask.sum() - labels.shape[0]
    num_negative_pairs = negative_mask.sum()

    if num_negative_pairs == 0 or num_positive_pairs == 0:
        print("all pos or neg")
        return torch.tensor(0.0, device=embeddings.device)

    # pairwise Euclidean distances, normalized by sqrt of the embedding dimension
    distances = torch.cdist(embeddings, embeddings) / np.sqrt(embeddings.shape[1])

    # subsample the larger pair set so positives and negatives are balanced
    if num_positive_pairs > num_negative_pairs:
        positive_mask.fill_diagonal_(False)  # drop self-pairs before sampling
        positive_indices = torch.nonzero(positive_mask)
        random_positive_indices = torch.randperm(len(positive_indices))[:num_negative_pairs]
        selected_positive_indices = positive_indices[random_positive_indices]

        negative_distances = distances[negative_mask].view(-1, 1)
        positive_distances = distances[selected_positive_indices[:, 0],
                                       selected_positive_indices[:, 1]].view(-1, 1)
    else:
        negative_indices = torch.nonzero(negative_mask)
        random_negative_indices = torch.randperm(len(negative_indices))[:num_positive_pairs]
        selected_negative_indices = negative_indices[random_negative_indices]

        positive_mask.fill_diagonal_(False)
        positive_distances = distances[positive_mask].view(-1, 1)
        negative_distances = distances[selected_negative_indices[:, 0],
                                       selected_negative_indices[:, 1]].view(-1, 1)

    # margin ranking: positives should be closer than negatives by at least `margin`
    loss = (positive_distances - negative_distances + margin).clamp(min=0).mean()
    return loss
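
# Quick sanity check for contrastive_loss (illustrative only, not part of training):
#   emb = torch.randn(6, 8)
#   lab = torch.tensor([0, 0, 1, 1, 2, 2])
#   contrastive_loss(emb, lab, margin=1.0)  # -> non-negative scalar tensor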


def forward_HGT(data, model, mlpmodel):
    data = data.to(device)
    batch = data['vertices'].batch
    # use the vertex coordinates as input features
    data["vertices"]['x'] = data.pos
    label = data.y.long().view(-1)

    optimizer.zero_grad()

    output = model(data.x_dict, data.edge_index_dict)
    # sum-pool node embeddings into one embedding per graph
    graph_embeddings = global_add_pool(output, batch)
    graph_embeddings = graph_embeddings.clamp(max=1e6)  # guard against exploding activations
    c_loss = contrastive_loss(graph_embeddings, label, margin=1)
    output = mlpmodel(graph_embeddings)

    loss = criterion(output, label)
    loss += c_loss * args.loss_coef
    return loss, c_loss * args.loss_coef, output, label


def forward(data, model, mlpmodel):
    data = data.to(device)
    edge_index1 = data['vertices', 'inside', 'vertices']['edge_index']
    edge_index2 = data['vertices', 'apart', 'vertices']['edge_index']
    combined_edge_index = torch.cat([edge_index1, edge_index2], 1)
    num_edge_inside = edge_index1.shape[1]

    if args.spanning_tree == 'T':
        # draw random weights, build a spanning tree over all edges, then keep only
        # the 'apart' edges that appear in the tree; all 'inside' edges are retained
        edge_weight = torch.rand(combined_edge_index.shape[1]) + 1
        undirected_spanning_edge = util.build_spanning_tree_edge(
            combined_edge_index, edge_weight, num_nodes=data.pos.shape[0])

        edge_set_1 = set(map(tuple, edge_index2.t().tolist()))
        edge_set_2 = set(map(tuple, undirected_spanning_edge.t().tolist()))

        common_edges = edge_set_1.intersection(edge_set_2)
        common_edges_tensor = torch.tensor(list(common_edges), dtype=torch.long).t().to(device)
        combined_edge_index = torch.cat([edge_index1, common_edges_tensor], 1)
    x, batch = data.pos, data['vertices'].batch
    label = data.y.long().view(-1)

    num_nodes = x.shape[0]
    # enumerate second-order (triplet) indices over the retained edges
    edge_index_2rd, num_triplets_real, edx_jk, edx_ij = util.triplets(combined_edge_index, num_nodes)
    optimizer.zero_grad()
    input_feature = torch.zeros([x.shape[0], args.h_ch], device=device)
    output = model(input_feature, x, [edge_index1, edge_index2], edge_index_2rd,
                   edx_jk, edx_ij, batch, num_edge_inside, args.edge_rep)
    # concatenate the outputs of all interaction blocks along the feature dimension
    output = torch.cat(output, dim=1)
    graph_embeddings = global_add_pool(output, batch)
    graph_embeddings = graph_embeddings.clamp(max=1e6)  # guard against exploding activations
    c_loss = contrastive_loss(graph_embeddings, label, margin=1)
    output = mlpmodel(graph_embeddings)

    loss = criterion(output, label)
    loss += c_loss * args.loss_coef
    return loss, c_loss * args.loss_coef, output, label
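
# The heterogeneous graphs consumed above are assumed (inferred from the accesses
# in forward/forward_HGT) to carry:
#   data['vertices'].pos                               -- vertex coordinates
#   data['vertices', 'inside', 'vertices'].edge_index  -- 'inside' relation edges
#   data['vertices', 'apart', 'vertices'].edge_index   -- 'apart' relation edges
#   data.y                                             -- one class label per graph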


def train(train_Loader, model, mlpmodel):
    epochloss = 0
    epochcloss = 0
    y_hat, y_true, y_hat_logit = [], [], []
    optimizer.zero_grad()
    model.train()
    mlpmodel.train()
    for i, data in enumerate(train_Loader):
        if args.model == "our":
            loss, c_loss, output, label = forward(data, model, mlpmodel)
        elif args.model in ["HGT", "HAN"]:
            loss, c_loss, output, label = forward_HGT(data, model, mlpmodel)

        loss.backward()
        optimizer.step()
        epochloss += loss.detach().cpu()
        epochcloss += c_loss.detach().cpu()

        # record top-1 predictions and logits for epoch-level metrics
        _, pred = output.topk(1, dim=1, largest=True, sorted=True)
        pred, label, output = pred.cpu(), label.cpu(), output.cpu()
        y_hat += list(pred.detach().numpy().reshape(-1))
        y_true += list(label.detach().numpy().reshape(-1))
        y_hat_logit += list(output.detach().numpy())
    return epochloss.item() / len(train_Loader), epochcloss.item() / len(train_Loader), y_hat, y_true, y_hat_logit


def test(loader, model, mlpmodel):
    y_hat, y_true, y_hat_logit = [], [], []
    loss_total, pred_num = 0, 0
    model.eval()
    mlpmodel.eval()
    with torch.no_grad():
        for data in loader:
            if args.model == "our":
                loss, c_loss, output, label = forward(data, model, mlpmodel)
            elif args.model in ["HGT", "HAN"]:
                loss, c_loss, output, label = forward_HGT(data, model, mlpmodel)

            _, pred = output.topk(1, dim=1, largest=True, sorted=True)
            pred, label, output = pred.cpu(), label.cpu(), output.cpu()
            y_hat += list(pred.detach().numpy().reshape(-1))
            y_true += list(label.detach().numpy().reshape(-1))
            y_hat_logit += list(output.detach().numpy())

            pred_num += len(label.reshape(-1, 1))
            loss_total += loss.detach() * len(label.reshape(-1, 1))
    return loss_total / pred_num, y_hat, y_true, y_hat_logit


def main(args, train_Loader, val_Loader, test_Loader):
    best_val_trigger = -1
    old_lr = 1e3
    # timestamp used to name this run's log, info, and checkpoint files
    suffix = datetime.now().strftime("%b%d-%H:%M:%S")
    if args.log:
        writer = SummaryWriter(os.path.join(tensorboard_dir, suffix))

    for epoch in range(args.nepoch):
        train_loss, train_closs, y_hat, y_true, y_hat_logit = train(train_Loader, model, mlpmodel)

        train_acc = util.calculate(y_hat, y_true, y_hat_logit)
        try:
            util.record({"loss": train_loss, "closs": train_closs, "acc": train_acc}, epoch, writer, "Train")
        except Exception:
            pass
        util.print_1(epoch, 'Train', {"loss": train_loss, "closs": train_closs, "acc": train_acc})
        if epoch % args.test_per_round == 0:
            val_loss, yhat_val, ytrue_val, yhatlogit_val = test(val_Loader, model, mlpmodel)
            test_loss, yhat_test, ytrue_test, yhatlogit_test = test(test_Loader, model, mlpmodel)
            val_acc = util.calculate(yhat_val, ytrue_val, yhatlogit_val)
            try:
                util.record({"loss": val_loss, "acc": val_acc}, epoch, writer, "Val")
            except Exception:
                pass
            util.print_1(epoch, 'Val', {"loss": val_loss, "acc": val_acc}, color=blue)
            test_acc = util.calculate(yhat_test, ytrue_test, yhatlogit_test)
            try:
                util.record({"loss": test_loss, "acc": test_acc}, epoch, writer, "Test")
            except Exception:
                pass
            util.print_1(epoch, 'Test', {"loss": test_loss, "acc": test_acc}, color=blue)
            val_trigger = val_acc
            if val_trigger > best_val_trigger:
                best_val_trigger = val_trigger
                best_model = copy.deepcopy(model)
                best_mlpmodel = copy.deepcopy(mlpmodel)
                best_info = [epoch, val_trigger]
        """
        update lr when epoch >= 30
        """
        if epoch >= 30:
            lr = scheduler.optimizer.param_groups[0]['lr']
            if old_lr != lr:
                print(red('lr'), epoch, lr, sep=', ')
                old_lr = lr
            scheduler.step(val_trigger)
    """
    use best model to get best model result
    """
    val_loss, yhat_val, ytrue_val, yhat_logit_val = test(val_Loader, best_model, best_mlpmodel)
    test_loss, yhat_test, ytrue_test, yhat_logit_test = test(test_Loader, best_model, best_mlpmodel)

    val_acc = util.calculate(yhat_val, ytrue_val, yhat_logit_val)
    util.print_1(best_info[0], 'BestVal', {"loss": val_loss, "acc": val_acc}, color=blue)
    test_acc = util.calculate(yhat_test, ytrue_test, yhat_logit_test)
    util.print_1(best_info[0], 'BestTest', {"loss": test_loss, "acc": test_acc}, color=blue)
    if args.model == "our":
        print(best_model.att)

    """
    save training info and best result
    """
    result_file = os.path.join(info_dir, suffix)
    with open(result_file, 'w') as f:
        print("Random Seed: ", Seed, file=f)
        print(f"acc val : {val_acc:.3f}, Test : {test_acc:.3f}", file=f)
        print(f"Best info: {best_info}", file=f)
        for i in [[a, getattr(args, a)] for a in args.__dict__]:
            print(i, file=f)
    to_save_dict = {'model': best_model.state_dict(), 'mlpmodel': best_mlpmodel.state_dict(),
                    'args': args, 'labels': ytrue_test, 'yhat': yhat_test, 'yhat_logit': yhat_logit_test}
    torch.save(to_save_dict, os.path.join(model_dir, suffix + '.pth'))
    print("done")


if __name__ == '__main__':
    """
    build dir
    """
    os.makedirs(args.save_dir, exist_ok=True)
    tensorboard_dir = os.path.join(args.save_dir, 'log')
    os.makedirs(tensorboard_dir, exist_ok=True)
    model_dir = os.path.join(args.save_dir, 'model')
    os.makedirs(model_dir, exist_ok=True)
    info_dir = os.path.join(args.save_dir, 'info')
    os.makedirs(info_dir, exist_ok=True)

    # fixed seed for the train/val/test split, independent of the model seed below
    Seed = 0
    test_ratio = 0.2
    print("data splitting Random Seed: ", Seed)
    if args.dataset in ['mnist', 'mnist_sparse']:
        train_ds, val_ds, test_ds = dataset.get_mnist_dataset(args.data_dir, Seed, test_ratio=test_ratio)
    elif args.dataset == 'building':
        train_ds, val_ds, test_ds = dataset.get_building_dataset(args.data_dir, Seed, test_ratio=test_ratio)
    elif args.dataset == 'mbuilding':
        train_ds, val_ds, test_ds = dataset.get_mbuilding_dataset(args.data_dir, Seed, test_ratio=test_ratio)
    elif args.dataset == 'sbuilding':
        train_ds, val_ds, test_ds = dataset.get_sbuilding_dataset(args.data_dir, Seed, test_ratio=test_ratio)
    elif args.dataset == 'smnist':
        train_ds, val_ds, test_ds = dataset.get_smnist_dataset(args.data_dir, Seed, test_ratio=test_ratio)
    elif args.dataset == 'dbp':
        train_ds, val_ds, test_ds = dataset.get_dbp_dataset(args.data_dir, Seed, test_ratio=test_ratio)
    if args.extent_norm == "T":
        # rescale all coordinates into a common extent
        train_ds = dataset.affine_transform_to_range(train_ds, target_range=(-1, 1))
        val_ds = dataset.affine_transform_to_range(val_ds, target_range=(-1, 1))
        test_ds = dataset.affine_transform_to_range(test_ds, target_range=(-1, 1))

    train_loader = torch_geometric.loader.DataLoader(train_ds, batch_size=args.train_batch,
                                                     shuffle=False, pin_memory=True, drop_last=True)
    val_loader = torch_geometric.loader.DataLoader(val_ds, batch_size=args.test_batch,
                                                   shuffle=False, pin_memory=True)
    test_loader = torch_geometric.loader.DataLoader(test_ds, batch_size=args.test_batch,
                                                    shuffle=False, pin_memory=True)
    """
    set model seed
    """
    Seed = args.man_seed if args.manualSeed else random.randint(1, 10000)
    print("Random Seed: ", Seed)
    print(args)
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)
    main(args, train_loader, val_loader, test_loader)