|
import torch |
|
import pandas as pd |
|
import numpy as np |
|
|
|
from matplotlib import cm |
|
import matplotlib.pyplot as plt |
|
import scipy |
|
import torch.nn.functional as F |
|
import torchvision |
|
|
|
from sklearn.metrics import explained_variance_score,mean_squared_error,mean_absolute_error,r2_score,precision_score,recall_score,f1_score,roc_auc_score,roc_curve, auc,confusion_matrix |
|
from sklearn.feature_selection import r_regression |
|
|
|
from torch_sparse import SparseTensor |
|
from scipy.sparse import csr_matrix |
|
from scipy.sparse.csgraph import minimum_spanning_tree |
|
from math import pi as PI |
|
|
|
def scipy_spanning_tree(edge_index, edge_weight,num_nodes ): |
|
row, col = edge_index.cpu() |
|
edge_weight=edge_weight.cpu() |
|
cgraph = csr_matrix((edge_weight, (row, col)), shape=(num_nodes, num_nodes)) |
|
Tcsr = minimum_spanning_tree(cgraph) |
|
tree_row, tree_col = Tcsr.nonzero() |
|
spanning_edges = np.stack([tree_row,tree_col],0) |
|
return spanning_edges |
|
|
|
def build_spanning_tree_edge(edge_index,edge_weight, num_nodes): |
|
spanning_edges = scipy_spanning_tree(edge_index, edge_weight,num_nodes,) |
|
|
|
spanning_edges = torch.tensor(spanning_edges, dtype=torch.long, device=edge_index.device) |
|
spanning_edges_undirected = torch.cat([spanning_edges,torch.stack([spanning_edges[1],spanning_edges[0]])],1) |
|
return spanning_edges_undirected |
|
|
|
|
|
|
|
|
|
def record(values,epoch,writer,phase="Train"): |
|
""" tfboard write """ |
|
for key,value in values.items(): |
|
writer.add_scalar(key+"/"+phase,value,epoch) |
|
def calculate(y_hat,y_true,y_hat_logit): |
|
""" calculate five metrics using y_hat, y_true, y_hat_logit """ |
|
train_acc=(np.array(y_hat) == np.array(y_true)).sum()/len(y_true) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return train_acc |
|
|
|
|
|
def print_1(epoch,phase,values,color=None): |
|
""" print epoch info""" |
|
if color is not None: |
|
print(color( f"epoch[{epoch:d}] {phase}"+ " ".join([f"{key}={value:.3f}" for key, value in values.items()]) )) |
|
else: |
|
print(( f"epoch[{epoch:d}] {phase}"+ " ".join([f"{key}={value:.3f}" for key, value in values.items()]) )) |
|
|
|
def get_angle(v1, v2): |
|
if v1.shape[1]==2: |
|
v1=F.pad(v1, (0, 1),value=0) |
|
if v2.shape[1]==2: |
|
v2= F.pad(v2, (0, 1),value=0) |
|
return torch.atan2( torch.cross(v1, v2, dim=1).norm(p=2, dim=1), (v1 * v2).sum(dim=1)) |
|
def get_theta(v1, v2): |
|
|
|
angle=get_angle(v1, v2) |
|
if v1.shape[1]==2: |
|
v1=F.pad(v1, (0, 1),value=0) |
|
if v2.shape[1]==2: |
|
v2= F.pad(v2, (0, 1),value=0) |
|
v = torch.cross(v1, v2, dim=1)[...,2] |
|
flag = torch.sign((v)) |
|
flag[flag==0]=-1 |
|
return angle*flag |
|
|
|
def triplets(edge_index, num_nodes): |
|
row, col = edge_index |
|
|
|
value = torch.arange(row.size(0), device=row.device) |
|
adj_t = SparseTensor(row=row, col=col, value=value, |
|
sparse_sizes=(num_nodes, num_nodes)) |
|
adj_t_col = adj_t[:,row] |
|
num_triplets = adj_t_col.set_value(None).sum(dim=0).to(torch.long) |
|
|
|
idx_j = row.repeat_interleave(num_triplets) |
|
idx_i = col.repeat_interleave(num_triplets) |
|
edx_2nd = value.repeat_interleave(num_triplets) |
|
idx_k = adj_t_col.t().storage.col() |
|
edx_1st = adj_t_col.t().storage.value() |
|
mask1 = (idx_i == idx_k) & (idx_j != idx_i) |
|
mask2 = (idx_i == idx_j) & (idx_j != idx_k) |
|
mask3 = (idx_j == idx_k) & (idx_i != idx_k) |
|
mask = ~(mask1 | mask2 | mask3) |
|
idx_i, idx_j, idx_k, edx_1st, edx_2nd = idx_i[mask], idx_j[mask], idx_k[mask], edx_1st[mask], edx_2nd[mask] |
|
|
|
num_triplets_real = torch.cumsum(num_triplets, dim=0) - torch.cumsum(~mask, dim=0)[torch.cumsum(num_triplets, dim=0)-1] |
|
|
|
return torch.stack([idx_i, idx_j, idx_k]), num_triplets_real.to(torch.long), edx_1st, edx_2nd |
|
|