# PolygonGNN / train_new.py
import argparse
import copy
import os
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch_geometric
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn import MLP
from torch_geometric.nn.pool import global_add_pool

import dataset
import model_new
import util
from model_new import Smodel

# Anomaly detection aids debugging of backward passes but slows training
torch.autograd.set_detect_anomaly(True)
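# ANSI escape helpers for colored console output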
blue = lambda x: '\033[94m' + x + '\033[0m'
red = lambda x: '\033[31m' + x + '\033[0m'
green = lambda x: '\033[32m' + x + '\033[0m'
yellow = lambda x: '\033[33m' + x + '\033[0m'
greenline = lambda x: '\033[42m' + x + '\033[0m'
yellowline = lambda x: '\033[43m' + x + '\033[0m'
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model',default="our", type=str)
parser.add_argument('--train_batch', default=64, type=int)
parser.add_argument('--test_batch', default=128, type=int)
parser.add_argument('--share', type=str, default="0")
parser.add_argument('--edge_rep', type=str, default="True")
parser.add_argument('--batchnorm', type=str, default="True")
parser.add_argument('--extent_norm', type=str, default="T")
parser.add_argument('--spanning_tree', type=str, default="T")
parser.add_argument('--loss_coef', default=0.1, type=float)
parser.add_argument('--h_ch', default=512, type=int)
parser.add_argument('--localdepth', type=int, default=1)
parser.add_argument('--num_interactions', type=int, default=4)
parser.add_argument('--finaldepth', type=int, default=4)
parser.add_argument('--classifier_depth', type=int, default=4)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument('--log', type=str, default="True")
parser.add_argument('--test_per_round', type=int, default=10)
parser.add_argument('--patience', type=int, default=30) #scheduler
parser.add_argument('--nepoch', type=int, default=301)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--manualSeed', type=str, default="False")
parser.add_argument('--man_seed', type=int, default=12345)
args = parser.parse_args()
    args.log = args.log == "True"
    args.edge_rep = args.edge_rep == "True"
    args.batchnorm = args.batchnorm == "True"
    args.save_dir = os.path.join('./save/', args.dataset, args.model)
    args.manualSeed = args.manualSeed == "True"
return args
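# Example invocation (flags as defined in get_args above; dataset names must
# match the branches below):
#   python train_new.py --dataset mnist --model our --train_batch 64 --lr 1e-4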
args = get_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion=nn.CrossEntropyLoss()
if args.dataset in ["mnist"]:
x_out=90
args.data_dir='data/multi_mnist_with_index.pkl'
elif args.dataset in ["mnist_sparse"]:
x_out=90
args.data_dir='data/multi_mnist_sparse.pkl'
elif args.dataset in ["building"]:
x_out=100
args.data_dir='data/building_with_index.pkl'
elif args.dataset in ["mbuilding"]:
x_out=100
args.data_dir='data/mp_building.pkl'
elif args.dataset in ["sbuilding"]:
x_out=10
args.data_dir='data/single_building.pkl'
elif args.dataset in ["smnist"]:
x_out=10
args.data_dir='data/single_mnist.pkl'
elif args.dataset in ["dbp"]:
    x_out=2
    args.data_dir='data/triple_building_600.pkl'
else:
    raise ValueError(f"unknown dataset: {args.dataset}")
if args.model=="our":
model=Smodel(h_channel=args.h_ch,input_featuresize=args.h_ch,\
localdepth=args.localdepth,num_interactions=args.num_interactions,finaldepth=args.finaldepth,share=args.share,batchnorm=args.batchnorm)
mlpmodel=MLP(in_channels=args.h_ch*args.num_interactions, hidden_channels=args.h_ch,out_channels=x_out, num_layers=args.classifier_depth,dropout=args.dropout)
elif args.model=="HGT":
model=model_new.HGT(hidden_channels=args.h_ch, out_channels=args.h_ch, num_heads=2, num_layers=args.num_interactions)
mlpmodel=MLP(in_channels=args.h_ch, hidden_channels=args.h_ch,out_channels=x_out, num_layers=args.classifier_depth,dropout=args.dropout)
elif args.model=="HAN":
model=model_new.HAN(hidden_channels=args.h_ch, out_channels=args.h_ch, num_heads=2, num_layers=args.num_interactions)
mlpmodel=MLP(in_channels=args.h_ch, hidden_channels=args.h_ch,out_channels=x_out, num_layers=args.classifier_depth,dropout=args.dropout)
model.to(device)
mlpmodel.to(device)
opt_list=list(model.parameters())+list(mlpmodel.parameters())
optimizer = torch.optim.Adam( opt_list, lr=args.lr)
# scheduler.step() receives validation accuracy, so track the maximum, not the default minimum
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=args.patience, min_lr=1e-8)
def contrastive_loss(embeddings, labels, margin):
    """Margin-based contrastive loss: mean(clamp(d_pos - d_neg + margin, min=0))."""
    positive_mask = labels.view(-1, 1) == labels.view(1, -1)
    negative_mask = ~positive_mask
    # Count pairs; the diagonal (self-pairs) is excluded from the positive count
    num_positive_pairs = positive_mask.sum() - labels.shape[0]
    num_negative_pairs = negative_mask.sum()
    # If all pairs share a label (or none do), return a zero placeholder loss
    if num_negative_pairs == 0 or num_positive_pairs == 0:
        print("all pos or neg")
        return torch.zeros((), dtype=torch.float, device=embeddings.device)
    # Pairwise Euclidean distances, scaled by sqrt(dim) to keep magnitudes stable
    distances = torch.cdist(embeddings, embeddings) / np.sqrt(embeddings.shape[1])
    if num_positive_pairs > num_negative_pairs:
        # Sample as many positive pairs as there are negative pairs
        positive_mask.fill_diagonal_(False)  # exclude self-pairs before sampling
        positive_indices = torch.nonzero(positive_mask)
        random_positive_indices = torch.randperm(len(positive_indices))[:num_negative_pairs]
        selected_positive_indices = positive_indices[random_positive_indices]
        # Use all negative pairs
        negative_distances = distances[negative_mask].view(-1, 1)
        positive_distances = distances[selected_positive_indices[:, 0], selected_positive_indices[:, 1]].view(-1, 1)
    else:  # the common case: more negative pairs than positive pairs
        # Sample as many negative pairs as there are positive pairs
        negative_indices = torch.nonzero(negative_mask)
        random_negative_indices = torch.randperm(len(negative_indices))[:num_positive_pairs]
        selected_negative_indices = negative_indices[random_negative_indices]
        # Use all positive pairs, excluding self-pairs
        positive_mask.fill_diagonal_(False)
        positive_distances = distances[positive_mask].view(-1, 1)
        negative_distances = distances[selected_negative_indices[:, 0], selected_negative_indices[:, 1]].view(-1, 1)
    # Hinge: positives should sit at least `margin` closer than negatives
    loss = (positive_distances - negative_distances + margin).clamp(min=0).mean()
    return loss
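# Quick sanity-check sketch (illustrative shapes only, not executed here):
#   emb = torch.randn(8, 16); lab = torch.randint(0, 2, (8,))
#   contrastive_loss(emb, lab, margin=1.0)  # -> non-negative scalar tensor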
def forward_HGT(data,model,mlpmodel):
data = data.to(device)
x,batch=data.pos, data['vertices'].batch
data["vertices"]['x']=data.pos
label=data.y.long().view(-1)
optimizer.zero_grad()
output=model(data.x_dict, data.edge_index_dict)
    graph_embeddings = global_add_pool(output, batch)
    graph_embeddings.clamp_(max=1e6)  # guard against exploding embeddings
c_loss=contrastive_loss(graph_embeddings,label,margin=1)
output=mlpmodel(graph_embeddings)
loss = criterion(output, label)
loss+=c_loss*args.loss_coef
return loss,c_loss*args.loss_coef,output,label
def forward(data,model,mlpmodel):
data = data.to(device)
    edge_index1 = data['vertices', 'inside', 'vertices']['edge_index']
    edge_index2 = data['vertices', 'apart', 'vertices']['edge_index']
    combined_edge_index = torch.cat([edge_index1, edge_index2], 1)
    num_edge_inside = edge_index1.shape[1]
    if args.spanning_tree == 'T':
        # Sample a random spanning tree over all edges, then keep every
        # "inside" edge plus only the "apart" edges the tree retained
        edge_weight = torch.rand(combined_edge_index.shape[1]) + 1
        undirected_spanning_edge = util.build_spanning_tree_edge(combined_edge_index, edge_weight, num_nodes=data.pos.shape[0])
        edge_set_1 = set(map(tuple, edge_index2.t().tolist()))
        edge_set_2 = set(map(tuple, undirected_spanning_edge.t().tolist()))
        common_edges = edge_set_1.intersection(edge_set_2)
        common_edges_tensor = torch.tensor(list(common_edges), dtype=torch.long).t().to(device)
        combined_edge_index = torch.cat([edge_index1, common_edges_tensor], 1)
x,batch=data.pos, data['vertices'].batch
label=data.y.long().view(-1)
num_nodes=x.shape[0]
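    # util.triplets is assumed to enumerate two-hop paths over the combined
    # edges; edx_jk / edx_ij index the two edges that form each triplet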
edge_index_2rd, num_triplets_real, edx_jk, edx_ij = util.triplets(combined_edge_index, num_nodes)
optimizer.zero_grad()
input_feature=torch.zeros([x.shape[0],args.h_ch],device=device)
output=model(input_feature,x,[edge_index1,edge_index2], edge_index_2rd,edx_jk, edx_ij,batch,num_edge_inside,args.edge_rep)
output=torch.cat(output,dim=1)
    graph_embeddings = global_add_pool(output, batch)
    graph_embeddings.clamp_(max=1e6)  # guard against exploding embeddings
c_loss=contrastive_loss(graph_embeddings,label,margin=1)
output=mlpmodel(graph_embeddings)
loss = criterion(output, label)
loss+=c_loss*args.loss_coef
return loss,c_loss*args.loss_coef,output,label
def train(train_Loader, model, mlpmodel):
epochloss=0
epochcloss=0
y_hat, y_true,y_hat_logit = [], [], []
optimizer.zero_grad()
model.train()
mlpmodel.train()
for i,data in enumerate(train_Loader):
if args.model=="our":
loss,c_loss,output,label =forward(data,model,mlpmodel)
elif args.model in ["HGT","HAN"]:
loss,c_loss,output,label =forward_HGT(data,model,mlpmodel)
loss.backward()
optimizer.step()
epochloss+=loss.detach().cpu()
epochcloss+=c_loss.detach().cpu()
_, pred = output.topk(1, dim=1, largest=True, sorted=True)
pred,label,output=pred.cpu(),label.cpu(),output.cpu()
y_hat += list(pred.detach().numpy().reshape(-1))
y_true += list(label.detach().numpy().reshape(-1))
y_hat_logit+=list(output.detach().numpy())
return epochloss.item()/len(train_Loader),epochcloss.item()/len(train_Loader),y_hat, y_true,y_hat_logit
def test(loader, model, mlpmodel):
y_hat, y_true,y_hat_logit = [], [], []
loss_total, pred_num = 0, 0
model.eval()
mlpmodel.eval()
with torch.no_grad():
for data in loader:
if args.model=="our":
loss,c_loss,output,label =forward(data,model,mlpmodel)
elif args.model in ["HGT","HAN"]:
loss,c_loss,output,label =forward_HGT(data,model,mlpmodel)
_, pred = output.topk(1, dim=1, largest=True, sorted=True)
pred,label,output=pred.cpu(),label.cpu(),output.cpu()
y_hat += list(pred.detach().numpy().reshape(-1))
y_true += list(label.detach().numpy().reshape(-1))
y_hat_logit+=list(output.detach().numpy())
pred_num += len(label.reshape(-1, 1))
loss_total += loss.detach() * len(label.reshape(-1, 1))
return loss_total/pred_num,y_hat, y_true, y_hat_logit
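# test() returns the sample-weighted mean loss along with predictions, labels, and logits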
def main(args,train_Loader,val_Loader,test_Loader):
best_val_trigger = -1
old_lr=1e3
suffix="{}{}-{}:{}:{}".format(datetime.now().strftime("%h"),
datetime.now().strftime("%d"),
datetime.now().strftime("%H"),
datetime.now().strftime("%M"),
datetime.now().strftime("%S"))
if args.log: writer = SummaryWriter(os.path.join(tensorboard_dir,suffix))
for epoch in range(args.nepoch):
train_loss,train_closs,y_hat, y_true,y_hat_logit=train(train_Loader,model,mlpmodel )
train_acc=util.calculate(y_hat,y_true,y_hat_logit)
        try: util.record({"loss":train_loss,"closs":train_closs,"acc":train_acc},epoch,writer,"Train")
        except NameError: pass  # writer only exists when --log is True
util.print_1(epoch,'Train',{"loss":train_loss,"closs":train_closs,"acc":train_acc})
if epoch % args.test_per_round == 0:
val_loss, yhat_val, ytrue_val, yhatlogit_val = test(val_Loader,model,mlpmodel )
test_loss, yhat_test, ytrue_test, yhatlogit_test = test(test_Loader,model,mlpmodel )
val_acc=util.calculate(yhat_val,ytrue_val,yhatlogit_val)
            try: util.record({"loss":val_loss,"acc":val_acc},epoch,writer,"Val")
            except NameError: pass  # writer only exists when --log is True
util.print_1(epoch,'Val',{"loss":val_loss,"acc":val_acc},color=blue)
test_acc=util.calculate(yhat_test,ytrue_test,yhatlogit_test)
            try: util.record({"loss":test_loss,"acc":test_acc},epoch,writer,"Test")
            except NameError: pass  # writer only exists when --log is True
util.print_1(epoch,'Test',{"loss":test_loss,"acc":test_acc},color=blue)
val_trigger=val_acc
if val_trigger > best_val_trigger:
best_val_trigger = val_trigger
best_model = copy.deepcopy(model)
best_mlpmodel=copy.deepcopy(mlpmodel)
best_info=[epoch,val_trigger]
"""
update lr when epoch≥30
"""
if epoch >= 30:
lr = scheduler.optimizer.param_groups[0]['lr']
if old_lr!=lr:
print(red('lr'), epoch, (lr), sep=', ')
old_lr=lr
scheduler.step(val_trigger)
"""
use best model to get best model result
"""
val_loss, yhat_val, ytrue_val, yhat_logit_val = test(val_Loader,best_model,best_mlpmodel)
test_loss, yhat_test, ytrue_test, yhat_logit_test= test(test_Loader,best_model,best_mlpmodel)
val_acc=util.calculate(yhat_val,ytrue_val,yhat_logit_val)
util.print_1(best_info[0],'BestVal',{"loss":val_loss,"acc":val_acc},color=blue)
test_acc=util.calculate(yhat_test,ytrue_test,yhat_logit_test)
util.print_1(best_info[0],'BestTest',{"loss":test_loss,"acc":test_acc},color=blue)
if args.model=="our":print(best_model.att)
"""
save training info and best result
"""
result_file=os.path.join(info_dir, suffix)
with open(result_file, 'w') as f:
print("Random Seed: ", Seed,file=f)
print(f"acc val : {val_acc:.3f}, Test : {test_acc:.3f}", file=f)
print(f"Best info: {best_info}", file=f)
        for name, value in vars(args).items():
            print([name, value], file=f)
to_save_dict={'model':best_model.state_dict(),'mlpmodel':best_mlpmodel.state_dict(),'args':args,'labels':ytrue_test,'yhat':yhat_test,'yhat_logit':yhat_logit_test}
torch.save(to_save_dict, os.path.join(model_dir,suffix+'.pth') )
print("done")
if __name__ == '__main__':
"""
build dir
"""
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir,exist_ok=True)
tensorboard_dir=os.path.join(args.save_dir,'log')
if not os.path.exists(tensorboard_dir):
os.makedirs(tensorboard_dir,exist_ok=True)
model_dir=os.path.join(args.save_dir,'model')
if not os.path.exists(model_dir):
os.makedirs(model_dir,exist_ok=True)
info_dir=os.path.join(args.save_dir,'info')
if not os.path.exists(info_dir):
os.makedirs(info_dir,exist_ok=True)
Seed = 0
test_ratio=0.2
print("data splitting Random Seed: ", Seed)
if args.dataset in ['mnist',"mnist_sparse"]:
train_ds,val_ds,test_ds=dataset.get_mnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
elif args.dataset in ['building']:
train_ds,val_ds,test_ds=dataset.get_building_dataset(args.data_dir,Seed,test_ratio=test_ratio)
elif args.dataset in ['mbuilding']:
train_ds,val_ds,test_ds=dataset.get_mbuilding_dataset(args.data_dir,Seed,test_ratio=test_ratio)
elif args.dataset in ['sbuilding']:
train_ds,val_ds,test_ds=dataset.get_sbuilding_dataset(args.data_dir,Seed,test_ratio=test_ratio)
elif args.dataset in ['smnist']:
train_ds,val_ds,test_ds=dataset.get_smnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['dbp']:
        train_ds,val_ds,test_ds=dataset.get_dbp_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    else:
        raise ValueError(f"unknown dataset: {args.dataset}")
if args.extent_norm=="T":
train_ds= dataset.affine_transform_to_range(train_ds,target_range=(-1, 1))
val_ds= dataset.affine_transform_to_range(val_ds,target_range=(-1, 1))
test_ds= dataset.affine_transform_to_range(test_ds,target_range=(-1, 1))
train_loader = torch_geometric.loader.DataLoader(train_ds,batch_size=args.train_batch, shuffle=False,pin_memory=True,drop_last=True)
val_loader = torch_geometric.loader.DataLoader(val_ds, batch_size=args.test_batch, shuffle=False, pin_memory=True)
test_loader = torch_geometric.loader.DataLoader(test_ds,batch_size=args.test_batch, shuffle=False,pin_memory=True)
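    # Note: training batches keep a fixed order (shuffle=False) and the last
    # incomplete batch is dropped (drop_last=True)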
"""
set model seed
"""
Seed = args.man_seed if args.manualSeed else random.randint(1, 10000)
Seed=3407
print("Random Seed: ", Seed)
print(args)
random.seed(Seed)
torch.manual_seed(Seed)
np.random.seed(Seed)
main(args,train_loader,val_loader,test_loader)