Dzy6 committed on
Commit
e551dda
0 Parent(s):
Files changed (8) hide show
  1. HeteroVG_MNIST.py +211 -0
  2. README.md +3 -0
  3. condainstall.txt +6 -0
  4. dataset.py +213 -0
  5. eval.py +261 -0
  6. model_new.py +224 -0
  7. train_new.py +397 -0
  8. util.py +102 -0
HeteroVG_MNIST.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ import torch
8
+ import random
9
+ import pickle as pkl
10
+ from tqdm import tqdm
11
+ from torch import Tensor
12
+ from scipy.spatial import distance_matrix
13
+ import torch_geometric
14
+ from torch_geometric.data import HeteroData
15
+ from torch_geometric.nn import to_hetero
16
+ # from shapely.geometry import Point, Polygon
17
+
18
+
19
def cross_product(p1, p2, p3):
    """Return the z-component of (p2 - p1) x (p3 - p1).

    Positive when p3 lies to the left of the directed line p1 -> p2,
    negative when it lies to the right, and zero when the three points
    are collinear.
    """
    ax, ay = p2[0] - p1[0], p2[1] - p1[1]
    bx, by = p3[0] - p1[0], p3[1] - p1[1]
    return ax * by - bx * ay
21
+
22
def colinear(p1, p2, p3):
    """Return True when p3 lies on segment p1-p2, strictly between the
    endpoints along at least one axis; otherwise return False.

    The equality (p1.y - p2.y)*(p2.x - p3.x) == (p1.x - p2.x)*(p2.y - p3.y)
    tests that the three points are collinear; the range checks then require
    p3 to fall strictly inside the segment's x- or y-extent (so shared
    endpoints do not count).

    Fix: the original fell off the end and returned an implicit None on the
    negative path; the falsy value happened to work in boolean contexts, but
    the contract is now explicit.
    """
    on_line = (p1[1] - p2[1]) * (p2[0] - p3[0]) == (p1[0] - p2[0]) * (p2[1] - p3[1])
    if on_line and min(p1[0], p2[0]) < p3[0] < max(p1[0], p2[0]):
        return True
    if on_line and min(p1[1], p2[1]) < p3[1] < max(p1[1], p2[1]):
        return True
    return False
25
+
26
def is_intersected(p1, p2, p3, p4):
    """Return True when segment p1-p2 intersects segment p3-p4.

    Two situations count as an intersection:
      * an endpoint of segment (p3, p4) lies strictly on segment (p1, p2), or
      * the segments properly cross, i.e. each segment's endpoints straddle
        the other segment (opposite cross-product signs on both tests).
    """
    if colinear(p1, p2, p3):
        return True
    if colinear(p1, p2, p4):
        return True
    straddles_first = cross_product(p1, p2, p3) * cross_product(p1, p2, p4) < 0
    straddles_second = cross_product(p3, p4, p1) * cross_product(p3, p4, p2) < 0
    return straddles_first and straddles_second
30
+
31
#Pos of single number
def read_singe(df, i):
    # Parse the polygon string stored at df[0][i] into an (N, 2) numpy array
    # of vertex coordinates.  Before the numbers are extracted, sentinel
    # coordinate pairs are spliced into the text at ring boundaries so that
    # downstream code (HeteroEdge) can recover the ring structure:
    #   "), (" not preceded by ')'  ->  ' 0.0 0.0'  (inner/hole ring boundary)
    #   "), (" preceded by ')'      ->  ' 1.0 1.0'  (new polygon boundary)
    # NOTE(review): the exact separator semantics are inferred from how
    # HeteroEdge consumes the sentinels — confirm against the JSON format.
    p_i = df[0][i]
    np_i = list(p_i)
    rflag = 0
    for x in range(len(p_i)-4):
        # rflag counts ')' characters seen so far; it offsets the insertion
        # index to compensate for text already spliced into np_i.
        # NOTE(review): the offset heuristic assumes one insert per ')' —
        # verify it stays aligned for deeply nested ring strings.
        if p_i[x] == ')': rflag+=1
        if p_i[x:x+4] == '), (' and p_i[x-1]!=')': np_i.insert(x+rflag+1,' 0.0 0.0')
        elif p_i[x:x+4] == '), (' and p_i[x-1]==')': np_i.insert(x+rflag+1,' 1.0 1.0')
    p_i = ''.join(np_i)


    # Extract every numeric token and pack consecutive (x, y) pairs into rows.
    # The first pair overwrites the uninitialised np.empty row; later pairs
    # are appended (x first, then y written into the same row).
    pos = np.empty((1,2))
    pi_nums = re.findall(r"\d+\.?\d*",p_i)
    j=0
    while j < len(pi_nums)-1:
        if j == 0:
            pos[0][0] = float(pi_nums[j])
            pos[0][1] = float(pi_nums[j+1])
            j+=2
            continue
        pos = np.append(pos,[[float(pi_nums[j]),0]],0)
        pos[j//2][1] = float(pi_nums[j+1])
        j+=2
    return pos
56
+
57
def Visi_Edge(pos_join, flag):
    # Build the two heterogeneous visibility edge sets over the joined
    # vertex array.
    #
    # pos_join : (N, 2) vertex coordinates of all polygon rings.
    # flag     : per-vertex marker array — 0 ordinary vertex, 1 ends a hole
    #            ring, 2/3 end a polygon ring.  NOTE(review): semantics
    #            inferred from HeteroEdge; confirm against data generation.
    #
    # Returns (inside_edge_index, apart_edge_index, exteriors), each as a
    # [[heads], [tails]] pair of vertex-index lists.
    inside_edge_index = [[],[]]
    apart_edge_index = [[],[]]

    # Coordinates as hashable tuples so `in` / list.index() work below.
    vg_point = []
    for i in range(len(pos_join)): vg_point.append((pos_join[i][0], pos_join[i][1]))

    # ---- "inside" edges: connect hole-ring vertices to visible vertices of
    # the enclosing polygon (edge kept only if it crosses no boundary segment).
    hole_p = np.where(flag==1)[0]
    if len(hole_p) != 0:
        last_id = 0
        for m in range(len(flag)):
            if flag[m] == 2 or flag[m] == 3:
                # A polygon with no hole markers inside its slice is skipped.
                if sum(flag[last_id:m]) == 0:
                    last_id = m+1
                    continue
                poly_i = vg_point[last_id:m+1]
                pos_i = np.arange(last_id, m+1)
                last_id = m+1
                for i in range(len(poly_i)):
                    if flag[pos_i[i]] == 1:
                        for j in range(i, len(flag)):
                            if flag[j]==1 or flag[j] == 2 or flag[j] == 3:
                                hole_i = poly_i[i:j+1]
                                pos_hole = np.arange(i, j+1)
                                # Candidate edge from each hole vertex p1 to
                                # each non-hole vertex p2 of the same polygon.
                                for p1 in hole_i:
                                    for p2 in poly_i:
                                        if p2 not in hole_i:
                                            inter_count = 0
                                            for d in range(len(poly_i)-1):
                                                p3, p4 = poly_i[d], poly_i[d+1]
                                                if is_intersected(p1, p2, p3, p4): inter_count+=1
                                            if inter_count==0:
                                                head, tail = pos_i[poly_i.index(p1)], pos_i[poly_i.index(p2)]
                                                inside_edge_index[0].append(head), inside_edge_index[1].append(tail)

    # ---- "apart" edges: candidate edges between vertices of different
    # polygons, kept when no exterior segment blocks the line of sight.
    for i in range(len(vg_point)):
        p1 = vg_point[i]
        # Polygon id = number of ring terminators (flag 2/3) seen before i.
        p1_id = np.count_nonzero(flag[0:i] == 2) + np.count_nonzero(flag[0:i] == 3)
        for j in range(len(vg_point)):
            p2 = vg_point[j]
            if p1 == p2: continue
            p2_id = np.count_nonzero(flag[0:j] == 2) + np.count_nonzero(flag[0:j] == 3)
            inter_count = 0
            # NOTE(review): `len(flag-1)` equals len(flag) (elementwise minus
            # on a numpy array) — almost certainly meant `len(flag)-1`.  Also
            # p3/p4 keep their previous values whenever flag[m] is a
            # terminator, and are unbound if the very first vertex is one.
            for m in range(len(flag-1)):
                if flag[m]!=1 and flag[m]!=2 and flag[m]!=3: p3, p4 = vg_point[m], vg_point[m+1]
                if is_intersected(p1, p2, p3, p4): inter_count+=1
            if inter_count==0:
                head, tail = vg_point.index(p1), vg_point.index(p2)
                # cc > 0 guarantees at least one ring terminator lies between
                # the two vertices, i.e. they belong to different rings.
                cc = np.count_nonzero(flag[min(head, tail):max(head, tail)] == 2) + np.count_nonzero(flag[min(head, tail): max(head, tail)] == 3)
                if p1_id!=p2_id and cc!=0: apart_edge_index[0].append(head), apart_edge_index[1].append(tail)
            #print(i)

    # ---- pruning: for each source vertex keep only its geometrically
    # nearest target (for both edge types); also collect consecutive
    # exterior boundary segments.
    ninside_edge_index = [[],[]]
    napart_edge_index = [[],[]]
    exteriors = [[],[]]

    if len(hole_p)!=0:
        for i in range(len(flag)):
            link_i = [pos_join[inside_edge_index[1][j]] for j in range(len(inside_edge_index[1])) if inside_edge_index[0][j]==i]
            if len(link_i)==0: continue
            ninside_edge_index[0].append(i)
            dis_matrix = distance_matrix([pos_join[i]], link_i)
            node_i = (link_i[np.argmin(dis_matrix[0])][0], link_i[np.argmin(dis_matrix[0])][1])
            ninside_edge_index[1].append(vg_point.index(node_i))

    # Exterior ring segments connect each ordinary vertex to its successor.
    for i in range(len(vg_point)-1):
        if flag[i]!=1 and flag[i]!=2 and flag[i]!=3 : exteriors[0].append(i), exteriors[1].append(i+1)

    for i in range(len(flag)):
        link_i = [pos_join[apart_edge_index[1][j]] for j in range(len(apart_edge_index[1])) if apart_edge_index[0][j]==i]
        if len(link_i)==0: continue
        napart_edge_index[0].append(i)
        dis_matrix = distance_matrix([pos_join[i]], link_i)
        node_i = (link_i[np.argmin(dis_matrix[0])][0], link_i[np.argmin(dis_matrix[0])][1])
        napart_edge_index[1].append(vg_point.index(node_i))

    inside_edge_index, apart_edge_index = ninside_edge_index, napart_edge_index

    return inside_edge_index, apart_edge_index, exteriors
136
+
137
+
138
def HeteroEdge(pos,k):
    # Convert the raw coordinate stream assembled by NNIST_HeteroVG (vertex
    # rows interleaved with sentinel rows) into a clean vertex array plus a
    # per-vertex flag array, then build the visibility edges via Visi_Edge.
    #
    # Presumed sentinels: (0,0) separates hole rings, (1,1) separates
    # polygons, (2,2) terminates one digit; k = number of digits expected.
    # NOTE(review): the sum-based deletions below would also drop genuine
    # vertices whose coordinates happen to sum to exactly 0, 2 or 4 —
    # confirm the coordinate range excludes that.
    pos_join = np.delete(pos, np.where(np.sum(pos, 1)==0)[0], axis=0)
    pos_join = np.delete(pos_join, np.where(np.sum(pos_join, 1)==2)[0], axis=0)
    pos_join = np.delete(pos_join, np.where(np.sum(pos_join, 1)==4)[0], axis=0)
    flag = np.zeros(len(pos_join))
    # Drop the initial [0,0] placeholder row from the raw stream.
    pos = np.delete(pos, 0, axis=0)
    count, id = 0, 0
    # Walk the stream, turning each sentinel row into a flag on the vertex
    # immediately before it.  A digit terminator (x==2) also shifts the
    # digit's x coordinates by `count` so successive digits do not overlap.
    # NOTE(review): the x==0/1/2 tests assume no real vertex has those exact
    # x values — verify.
    while count<k:
        for i in range(len(pos)):
            if pos[i][0]==0:
                flag[i-1]=1
                pos = np.delete(pos, i, axis=0)
                break
            elif pos[i][0]==1:
                flag[i-1]=2
                pos = np.delete(pos, i, axis=0)
                break
            elif pos[i][0]==2:
                flag[i-1]=3
                pos = np.delete(pos, i, axis=0)
                pos_join[id:i, 0]+=count
                count+=1
                id = i
                break
    pos_join = pos_join  # no-op kept from the original
    inside_edge_index, apart_edge_index, exteriors = Visi_Edge(pos_join, flag)

    return pos_join, inside_edge_index, apart_edge_index, exteriors
166
+
167
#build heterovg of k-digit from MNIST
def NNIST_HeteroVG(df, label_df, k):
    # Sample k random digit polygons from the polygon-MNIST dataframe,
    # concatenate their vertex streams (each terminated by a (2,2) sentinel
    # row), and build a HeteroData graph whose integer label is the k-digit
    # number the sampled digits spell.
    pos = [[0,0]]
    label = ''
    for i in np.random.randint(0, len(df), k):
        # The leading digit must not be 0 (the label is parsed with int()).
        # NOTE(review): random.randint's upper bound is inclusive, so the
        # redraw can produce i == len(df) and raise IndexError below.
        while True:
            if len(pos) == 1 and label_df[0][i] == 0: i = random.randint(0, len(df))
            else: break
        pos = np.append(pos, read_singe(df, i), 0)
        pos = np.append(pos, [[2,2]], 0)
        label = label+'%d'%(label_df[0][i])

    label = int(label)
    pos_join, inside, apart, exteriors = HeteroEdge(pos,k)

    data = HeteroData()

    # One dummy scalar feature per vertex; the geometry lives in data.pos.
    data['vertices'].x = torch.zeros((len(pos_join), 1), dtype=torch.float)
    data.y = torch.tensor(label, dtype=torch.int)
    data.pos = torch.tensor(pos_join, dtype=torch.float)

    # Symmetrise inside/apart edges (both directions); exterior ring
    # segments are appended to the "inside" relation in one direction only.
    data['vertices', 'inside', 'vertices'].edge_index = torch.tensor([inside[0]+inside[1]+exteriors[0],inside[1]+inside[0]+exteriors[1]], dtype=torch.long)
    data['vertices', 'apart', 'vertices'].edge_index = torch.tensor([apart[0]+apart[1],apart[1]+apart[0]], dtype=torch.long)
    data['vertices', 'inside', 'vertices'].edge_attr = torch.zeros((len(data['vertices', 'inside', 'vertices'].edge_index[0]),1), dtype=torch.float)
    data['vertices', 'apart', 'vertices'].edge_attr = torch.zeros((len(data['vertices', 'apart', 'vertices'].edge_index[0]),1), dtype=torch.float)

    return data
194
+
195
# Generation script: build N heterogeneous visibility graphs per digit
# count k in [2, K] and pickle the resulting list.  Paths point at a
# Colab-mounted Google Drive.
mnist_filename = '/content/drive/MyDrive/MINST_Polygons/polyMNIST/mnist_polygon_test.json'
label_filename = '/content/drive/MyDrive/MINST_Polygons/polyMNIST/mnist_label_test.json'
df = pd.read_json(mnist_filename)
label_df = pd.read_json(label_filename)

K = 2 # number of digits
N = 10 # number of generated graphs
multi_mnist_dataset = []
for k in range(2, K+1):
    for i in tqdm(range(N)):
        data = NNIST_HeteroVG(df, label_df, k=k)
        multi_mnist_dataset.append(data)

# Persist all generated graphs in a single pickle for the training code.
if not os.path.exists('/content/drive/MyDrive/MINST_Polygons/multi_mnist'):
    os.makedirs('/content/drive/MyDrive/MINST_Polygons/multi_mnist')
with open('/content/drive/MyDrive/MINST_Polygons/multi_mnist/multi_mnist.pkl','wb') as file:
    pkl.dump(multi_mnist_dataset, file)
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # KDD24 PolygonGNN: Representation Learning for Polygonal Geometries with Heterogeneous Visibility Graph
2
+ data is on [dropbox](https://www.dropbox.com/scl/fo/f7dir04pldz36n6m47m30/ABxnZk8Qyf16k0Yo75WqXpY?rlkey=f3lhgyv7um323ngpa2bmueimq&st=e4wg0uec&dl=0)
3
+
condainstall.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ conda create -n graph python=3.8
2
+ conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia
3
+ conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
4
+ conda install pyg -c pyg
5
+ conda install -c conda-forge pytorch_sparse
6
+ conda install matplotlib numpy ipykernel pandas tensorboard
dataset.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.utils.data as data
2
+ import os
3
+ import os.path
4
+ import torch
5
+ import numpy as np
6
+ import pandas as pd
7
+ import sys
8
+ import pickle
9
+ import time
10
+ import torchvision.datasets as datasets
11
+ import torchvision.transforms as transforms
12
+ from PIL import Image
13
+ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
14
+ from torchvision.datasets import VisionDataset
15
+ from torch.utils.data import Dataset
16
+ from datetime import date, timedelta,datetime
17
+ import random
18
+ import pickle as pkl
19
+ import string
20
+
21
# Letter classes used for building-footprint shapes.
valid_chars = 'EFHILOTUYZ'

# All ordered two-letter combinations (100 classes) for multi-building labels.
alphabetic_labels = [char1 + char2 for char1 in valid_chars for char2 in valid_chars]
alphabetic_labels.sort()
label_mapping = {label: idx for idx, label in enumerate(alphabetic_labels)} # to number
reverse_label_mapping = {v: k for k, v in label_mapping.items()} # to alphabetic

# Single-letter labels (10 classes) for the single-building dataset.
single_alphabetic_labels=[char1 for char1 in valid_chars]
single_alphabetic_labels.sort()
single_label_mapping = {label: idx for idx, label in enumerate(single_alphabetic_labels)}
single_reverse_label_mapping = {v: k for k, v in single_label_mapping.items()}
32
+
33
def get_mnist_dataset(data_dir='data/multi_mnist.pkl',Seed=0,test_ratio=0.2):
    """Load the pickled multi-MNIST graph dataset and split it.

    Every entry's label is shifted down by 10 (the smallest two-digit
    label maps to class 0).  After a seeded shuffle, `test_ratio` of the
    samples go to val, the same amount to test, and the rest to train.
    Returns (train_ds, val_ds, test_ds) as plain lists.
    """
    # Seed all RNGs so the shuffle/split is reproducible.
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)

    with open(data_dir, 'rb') as f:
        dataset = pkl.load(f)
    for entry in dataset:
        entry.y -= 10

    np.random.shuffle(dataset)
    n_holdout = int(np.around(test_ratio * len(dataset)))
    n_train = int(len(dataset) - 2 * n_holdout)
    train_ds = dataset[:n_train]
    val_ds = dataset[n_train:n_train + n_holdout]
    test_ds = dataset[n_train + n_holdout:]

    print(data_dir)
    print('Train: ' +str(len(train_ds)))
    print('Val : ' +str(len(val_ds)))
    print('Test : ' +str(len(test_ds)))

    return train_ds,val_ds,test_ds
57
+
58
def get_building_dataset(data_dir='data/building_with_index.pkl',Seed=0,test_ratio=0.2):
    """Load a pickled building-graph dataset and split into train/val/test.

    Each entry's two-letter string label is mapped to a class index via the
    module-level `label_mapping`.  `test_ratio` of the samples go to val and
    the same amount to test; the remainder is train.
    """
    # Seed all RNGs so the shuffle/split is reproducible.
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)

    with open(data_dir, 'rb') as f:
        dataset = pkl.load(f)
    for entry in dataset:
        entry.y = label_mapping[entry.y]

    np.random.shuffle(dataset)
    val_test_split = int(np.around( test_ratio * len(dataset) ))
    train_val_split = int(len(dataset)-2*val_test_split)
    train_ds = dataset[:train_val_split]
    val_ds = dataset[train_val_split:train_val_split+val_test_split]
    test_ds = dataset[train_val_split+val_test_split:]

    print(data_dir)
    print('Train: ' +str(len(train_ds)))
    print('Val : ' +str(len(val_ds)))
    print('Test : ' +str(len(test_ds)))

    return train_ds,val_ds,test_ds
82
+
83
def get_mbuilding_dataset(data_dir='data/mp_building.pkl',Seed=0,test_ratio=0.2):
    """Load the pickled multi-part building dataset and split it.

    Identical pipeline to get_building_dataset (same two-letter
    `label_mapping`), only the default pickle path differs.
    """
    # Seed all RNGs so the shuffle/split is reproducible.
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)

    with open(data_dir, 'rb') as f:
        dataset = pkl.load(f)
    for entry in dataset:
        entry.y = label_mapping[entry.y]

    np.random.shuffle(dataset)
    val_test_split = int(np.around( test_ratio * len(dataset) ))
    train_val_split = int(len(dataset)-2*val_test_split)
    train_ds = dataset[:train_val_split]
    val_ds = dataset[train_val_split:train_val_split+val_test_split]
    test_ds = dataset[train_val_split+val_test_split:]

    print(data_dir)
    print('Train: ' +str(len(train_ds)))
    print('Val : ' +str(len(val_ds)))
    print('Test : ' +str(len(test_ds)))

    return train_ds,val_ds,test_ds
107
+
108
def get_sbuilding_dataset(data_dir='data/single_building.pkl',Seed=0,test_ratio=0.2):
    """Load the pickled single-building dataset and split it.

    Labels are single letters, mapped to class indices via the module-level
    `single_label_mapping` (10 classes).
    """
    # Seed all RNGs so the shuffle/split is reproducible.
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)

    with open(data_dir, 'rb') as f:
        dataset = pkl.load(f)
    for entry in dataset:
        entry.y = single_label_mapping[entry.y]

    np.random.shuffle(dataset)
    val_test_split = int(np.around( test_ratio * len(dataset) ))
    train_val_split = int(len(dataset)-2*val_test_split)
    train_ds = dataset[:train_val_split]
    val_ds = dataset[train_val_split:train_val_split+val_test_split]
    test_ds = dataset[train_val_split+val_test_split:]

    print(data_dir)
    print('Train: ' +str(len(train_ds)))
    print('Val : ' +str(len(val_ds)))
    print('Test : ' +str(len(test_ds)))

    return train_ds,val_ds,test_ds
132
+
133
def get_smnist_dataset(data_dir='data/single_mnist.pkl',Seed=0,test_ratio=0.2):
    """Load the pickled single-digit MNIST graph dataset and split it.

    Labels are used as-is (no remapping).  After a seeded shuffle,
    `test_ratio` of the samples go to val, the same amount to test, and
    the rest to train.  Returns (train_ds, val_ds, test_ds).
    """
    # Seed all RNGs so the shuffle/split is reproducible.
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)

    with open(data_dir, 'rb') as f:
        dataset = pkl.load(f)

    np.random.shuffle(dataset)
    n_holdout = int(np.around(test_ratio * len(dataset)))
    n_train = int(len(dataset) - 2 * n_holdout)
    train_ds = dataset[:n_train]
    val_ds = dataset[n_train:n_train + n_holdout]
    test_ds = dataset[n_train + n_holdout:]

    print(data_dir)
    print('Train: ' +str(len(train_ds)))
    print('Val : ' +str(len(val_ds)))
    print('Test : ' +str(len(test_ds)))

    return train_ds,val_ds,test_ds
155
+
156
def get_dbp_dataset(data_dir='data/triple_building.pkl',Seed=0,test_ratio=0.2):
    """Load the pickled triple-building dataset for binary classification.

    Every label is binarised: y >= 1 becomes 1, otherwise 0.  After a
    seeded shuffle, `test_ratio` of the samples go to val, the same amount
    to test, and the rest to train.
    """
    # Seed all RNGs so the shuffle/split is reproducible.
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)

    with open(data_dir, 'rb') as f:
        dataset = pkl.load(f)
    for entry in dataset:
        entry.y = int(entry.y >= 1)

    np.random.shuffle(dataset)
    n_holdout = int(np.around(test_ratio * len(dataset)))
    n_train = int(len(dataset) - 2 * n_holdout)
    train_ds = dataset[:n_train]
    val_ds = dataset[n_train:n_train + n_holdout]
    test_ds = dataset[n_train + n_holdout:]

    print(data_dir)
    print('Train: ' +str(len(train_ds)))
    print('Val : ' +str(len(val_ds)))
    print('Test : ' +str(len(test_ds)))

    return train_ds,val_ds,test_ds
180
+
181
def affine_transform_to_range(ds, target_range=(-1, 1)):
    """Rescale each item's vertex coordinates into `target_range`, in place.

    Every item is normalised independently: its x and y extents are mapped
    linearly onto [target_range[0], target_range[1]].

    Fix: the original divided by (max - min) without a guard, producing
    inf/nan coordinates for items whose extent is zero along an axis; such
    a degenerate axis is now mapped to the midpoint of the target range.

    Returns the (mutated) input sequence `ds`.
    """
    lo, hi = target_range
    span = hi - lo
    midpoint = (lo + hi) / 2
    for item in ds:
        min_x = torch.min(item.pos[:, 0])
        min_y = torch.min(item.pos[:, 1])

        max_x = torch.max(item.pos[:, 0])
        max_y = torch.max(item.pos[:, 1])

        if max_x > min_x:
            scale_x = span / (max_x - min_x)
            translate_x = lo - min_x * scale_x
        else:
            # Zero extent in x: collapse the axis to the range midpoint.
            scale_x, translate_x = 0.0, midpoint
        if max_y > min_y:
            scale_y = span / (max_y - min_y)
            translate_y = lo - min_y * scale_y
        else:
            # Zero extent in y: collapse the axis to the range midpoint.
            scale_y, translate_y = 0.0, midpoint

        # Apply the affine transformation in place.
        item.pos[:, 0] = item.pos[:, 0] * scale_x + translate_x
        item.pos[:, 1] = item.pos[:, 1] * scale_y + translate_y
    return ds
199
+
200
class CustomDataset(Dataset):
    # Thin list-backed dataset wrapper around a pre-built list of samples.
    #
    # NOTE(review): this subclasses torch.utils.data.Dataset but defines
    # `len`/`get` (the torch_geometric.data.Dataset method names) instead of
    # `__len__`/`__getitem__` — confirm which loader actually consumes it.
    def __init__(self, data_list):
        super(CustomDataset, self).__init__()
        # Underlying list of samples, indexed directly by `get`.
        self.data_list = data_list

    def len(self):
        # Number of stored samples.
        return len(self.data_list)

    def get(self, idx):
        # Return the sample at position `idx` (no copy, no transform).
        return self.data_list[idx]
210
+
211
if __name__ == '__main__':
    # Smoke test: load the default multi-MNIST pickle and print split sizes.
    a,b,c=get_mnist_dataset()
    print("")
eval.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import torch
5
+ import pandas as pd
6
+ import numpy as np
7
+ import time
8
+ import torch.optim as optim
9
+ import scipy
10
+
11
+ from matplotlib import cm
12
+ import matplotlib.pyplot as plt
13
+ import json
14
+ import torch.nn.functional as F
15
+ from torch.nn.functional import softmax
16
+
17
+ torch.autograd.set_detect_anomaly(True)
18
+ import pickle
19
+ from torch.utils.tensorboard import SummaryWriter
20
+ import dataset,util
21
+ from model_new import Smodel
22
+ import model_new
23
+
24
+
25
+ import torch.nn as nn
26
+ import torchvision.transforms as transforms
27
+ import torchvision.datasets
28
+ import torchvision.models
29
+ import math
30
+ import shutil
31
+ import time
32
+ from datetime import date, timedelta,datetime
33
+ import torch_geometric
34
+ from torch_geometric.data import Data, DataLoader
35
+ from torch_geometric.nn import MessagePassing
36
+ from torch_geometric.utils import add_self_loops
37
+ from torch_geometric.nn import GIN,GATConv,MLP
38
+ from torch_geometric.nn.pool import global_mean_pool,global_add_pool
39
+ import csv
40
+
41
# ANSI terminal colour helpers: each wraps a string in a colour escape
# sequence and appends the reset code.
def blue(x):
    return '\033[94m' + x + '\033[0m'

def red(x):
    return '\033[31m' + x + '\033[0m'

def green(x):
    return '\033[32m' + x + '\033[0m'

def yellow(x):
    return '\033[33m' + x + '\033[0m'

def greenline(x):
    return '\033[42m' + x + '\033[0m'

def yellowline(x):
    return '\033[43m' + x + '\033[0m'
47
+
48
def get_args():
    # Build and parse the command-line interface.  String "True"/"False"
    # flags are converted to real booleans before returning, and the
    # checkpoint directory is derived from the dataset name.
    parser = argparse.ArgumentParser()
    # Model selection and batch sizes.
    parser.add_argument('--model',default="our", type=str)
    parser.add_argument('--train_batch', default=64, type=int)
    parser.add_argument('--test_batch', default=128, type=int)
    parser.add_argument('--share', type=str, default="0")
    parser.add_argument('--edge_rep', type=str, default="True")
    parser.add_argument('--batchnorm', type=str, default="True")
    parser.add_argument('--extent_norm', type=str, default="T")
    parser.add_argument('--spanning_tree', type=str, default="F")

    # Architecture / capacity hyper-parameters.
    parser.add_argument('--loss_coef', default=0.1, type=float)
    parser.add_argument('--h_ch', default=512, type=int)
    parser.add_argument('--localdepth', type=int, default=1)
    parser.add_argument('--num_interactions', type=int, default=4)
    parser.add_argument('--finaldepth', type=int, default=4)
    parser.add_argument('--classifier_depth', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.0)

    # Dataset choice, training schedule and seeding.
    parser.add_argument('--dataset', type=str, default='mnist')
    parser.add_argument('--log', type=str, default="True")
    parser.add_argument('--test_per_round', type=int, default=10)
    parser.add_argument('--patience', type=int, default=30) #scheduler
    parser.add_argument('--nepoch', type=int, default=201)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--manualSeed', type=str, default="False")
    parser.add_argument('--man_seed', type=int, default=12345)

    # Checkpoint filenames (under save/<dataset>/<model>/model) to evaluate.
    parser.add_argument("--targetfiles", nargs='+', type=str, default=["Dec11-14:44:32.pth","Nov13-14:30:48.pth"])
    args = parser.parse_args()
    # Convert the string-typed flags to booleans.
    args.log=True if args.log=="True" else False
    args.edge_rep=True if args.edge_rep=="True" else False
    args.batchnorm=True if args.batchnorm=="True" else False
    args.save_dir=os.path.join('./save/',args.dataset)
    args.manualSeed=True if args.manualSeed=="True" else False
    return args
84
+
85
# Parse CLI options once at import time; prefer GPU when available and use
# cross-entropy for all classification heads.
args = get_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion=nn.CrossEntropyLoss()
88
+
89
def forward_HGT(args,data,model,mlpmodel):
    """Run one forward pass for the HGT/HAN baselines.

    Uses the raw vertex coordinates as input node features, sum-pools the
    node embeddings into per-graph embeddings, classifies with `mlpmodel`
    and computes the cross-entropy loss.

    Fix: the original's `if args.dataset in ["dbp"]:` branch was byte-for-
    byte identical to its `else` branch; the dead conditional is removed.

    Returns (loss, logits, labels, graph_embeddings).
    """
    data = data.to(device)
    x, batch = data.pos, data['vertices'].batch
    # Feed the 2-D coordinates as node features.
    data["vertices"]['x'] = data.pos
    label = data.y.long().view(-1)

    output = model(data.x_dict, data.edge_index_dict)
    # Sum-pool node embeddings per graph; clamp to avoid overflow blowups.
    graph_embeddings = global_add_pool(output, batch)
    graph_embeddings.clamp_(max=1e6)

    output = mlpmodel(graph_embeddings)
    # log_probs = F.log_softmax(output, dim=1)

    loss = criterion(output, label)
    return loss, output, label, graph_embeddings
107
+
108
def forward(args,data,model,mlpmodel):
    """Run one forward pass for the "our" (PolygonGNN) model.

    Combines the 'inside' and 'apart' edge sets, optionally reduces them to
    a random spanning tree, builds second-order (triplet) indices, runs the
    backbone, sum-pools to graph embeddings and classifies.

    Fix: the original computed `num_nodes` only AFTER the spanning-tree
    branch that uses it, raising NameError whenever --spanning_tree True
    was passed; it is now computed up front.

    Returns (loss, logits, labels, graph_embeddings).
    """
    data = data.to(device)
    edge_index1 = data['vertices', 'inside', 'vertices']['edge_index']
    edge_index2 = data['vertices', 'apart', 'vertices']['edge_index']
    combined_edge_index = torch.cat([edge_index1, edge_index2], 1)

    x, batch = data.pos, data['vertices'].batch
    label = data.y.long().view(-1)
    num_nodes = x.shape[0]  # must precede the spanning-tree branch below

    if args.spanning_tree == 'True':
        # Random positive weights so each epoch samples a different tree.
        edge_weight = torch.rand(combined_edge_index.shape[1]) + 1
        combined_edge_index = util.build_spanning_tree_edge(combined_edge_index, edge_weight, num_nodes=num_nodes,)

    num_edge_inside = edge_index1.shape[1]
    """
    triplets are not the same for graphs when training
    """
    edge_index_2rd, num_triplets_real, edx_jk, edx_ij = util.triplets(combined_edge_index, num_nodes)

    # Node features start as zeros; geometry enters through x (positions).
    input_feature = torch.zeros([x.shape[0], args.h_ch], device=device)
    output = model(input_feature, x, [edge_index1, edge_index2], edge_index_2rd, edx_jk, edx_ij, batch, num_edge_inside, args.edge_rep)
    output = torch.cat(output, dim=1)
    graph_embeddings = global_add_pool(output, batch)
    graph_embeddings.clamp_(max=1e6)

    output = mlpmodel(graph_embeddings)
    # log_probs = F.log_softmax(output, dim=1)

    loss = criterion(output, label)
    return loss, output, label, graph_embeddings
138
def test(args,loader,model,mlpmodel,writer,reverse_mapping ):
    # Evaluate on `loader`: accumulate predictions, labels, logits and graph
    # embeddings, log the embeddings to the TensorBoard projector, and
    # return the size-weighted mean loss plus the collected arrays.
    y_hat, y_true,y_hat_logit = [], [], [],
    embeddings=[]

    loss_total, pred_num = 0, 0
    model.eval()
    mlpmodel.eval()
    with torch.no_grad():
        for data in loader:
            if args.model=="our":
                loss,output,label,embedding =forward(args,data,model,mlpmodel)
            elif args.model in ["HGT","HAN"]:
                loss,output,label,embedding =forward_HGT(args,data,model,mlpmodel)
            # Top-1 prediction per graph.
            _, pred = output.topk(1, dim=1, largest=True, sorted=True)
            pred,label,output=pred.cpu(),label.cpu(),output.cpu()
            y_hat += list(pred.detach().numpy().reshape(-1))
            y_true += list(label.detach().numpy().reshape(-1))
            y_hat_logit+=list(output.detach().numpy())
            embeddings.append(embedding)

            # Weight the running loss by batch size for a correct mean.
            pred_num += len(label.reshape(-1, 1))
            loss_total += loss.detach() * len(label.reshape(-1, 1))

    # Map numeric labels back to display names for the projector metadata.
    y_true_str=[reverse_mapping(item) for item in y_true]
    writer.add_embedding(torch.cat(embeddings,dim=0).detach().cpu(),metadata=y_true_str,tag="numbers")
    writer.close()
    return loss_total/pred_num,y_hat, y_true, y_hat_logit
165
+
166
def main(args,train_Loader,val_Loader,test_Loader):
    # Evaluate every checkpoint listed in args.targetfiles: rebuild the
    # model from the checkpoint's own saved args, load the weights, run the
    # test loader, and dump predictions, embeddings and accuracy.
    donefiles=os.listdir(os.path.join(args.save_dir,args.model,'model'))
    tensorboard_dir=os.path.join(args.save_dir,args.model,'log')
    # Choose how numeric labels map back to display names for embeddings.
    # NOTE(review): reverse_mapping stays unbound if args.dataset matches
    # none of these branches.
    if args.dataset in ["mnist","mnist_sparse"]:
        reverse_mapping=lambda x: x + 10
        # list(map(lambda x: x - 10, []))
    elif args.dataset in ["building","mbuilding"]:
        reverse_mapping=lambda x: dataset.reverse_label_mapping[x]
    elif args.dataset in ["sbuilding"]:
        reverse_mapping=lambda x: dataset.single_reverse_label_mapping[x]
    elif args.dataset in ["dbp","smnist"]:
        reverse_mapping=lambda x: x
    for file in donefiles:
        if file not in args.targetfiles:
            continue
        else:
            print(file)
            saved_dict=torch.load(os.path.join(args.save_dir,args.model,'model',file))
            # Classifier output width depends on the dataset the checkpoint
            # was trained on.
            if saved_dict['args'].dataset in ["mnist","mnist_sparse"]:
                x_out=90
            elif saved_dict['args'].dataset in ["building","mbuilding"]:
                x_out=100
            elif saved_dict['args'].dataset in ["sbuilding","smnist"]:
                x_out=10
            elif saved_dict['args'].dataset in ["dbp"]:
                x_out=2
            # Rebuild backbone + classifier exactly as at training time.
            if saved_dict['args'].model=="our":
                model=Smodel(h_channel=saved_dict['args'].h_ch,input_featuresize=saved_dict['args'].h_ch,\
                localdepth=saved_dict['args'].localdepth,num_interactions=saved_dict['args'].num_interactions,finaldepth=saved_dict['args'].finaldepth,share=saved_dict['args'].share,batchnorm=saved_dict['args'].batchnorm)
                mlpmodel=MLP(in_channels=saved_dict['args'].h_ch*saved_dict['args'].num_interactions, hidden_channels=saved_dict['args'].h_ch,out_channels=x_out, num_layers=saved_dict['args'].classifier_depth)
            elif saved_dict['args'].model=="HGT":
                model=model_new.HGT(hidden_channels=saved_dict['args'].h_ch, out_channels=saved_dict['args'].h_ch, num_heads=2, num_layers=saved_dict['args'].num_interactions)
                mlpmodel=MLP(in_channels=saved_dict['args'].h_ch, hidden_channels=saved_dict['args'].h_ch,out_channels=x_out, num_layers=saved_dict['args'].classifier_depth,dropout=saved_dict['args'].dropout)
            elif saved_dict['args'].model=="HAN":
                model=model_new.HAN(hidden_channels=saved_dict['args'].h_ch, out_channels=saved_dict['args'].h_ch, num_heads=2, num_layers=saved_dict['args'].num_interactions)
                mlpmodel=MLP(in_channels=saved_dict['args'].h_ch, hidden_channels=saved_dict['args'].h_ch,out_channels=x_out, num_layers=saved_dict['args'].classifier_depth,dropout=saved_dict['args'].dropout)
            model.to(device), mlpmodel.to(device)
            # NOTE(review): load_state_dict raises RuntimeError on mismatch,
            # not OSError — this handler likely never fires.
            try:
                model.load_state_dict(saved_dict['model'],strict=True)
                mlpmodel.load_state_dict(saved_dict['mlpmodel'],strict=True)
            except OSError:
                print('loadfail: ',file)
                pass
            print(saved_dict['args'])

            writer = SummaryWriter(os.path.join(tensorboard_dir,file+"_embedding"))
            test_loss, yhat_test, ytrue_test, yhatlogit_test = test(saved_dict['args'],test_Loader,model,mlpmodel,writer,reverse_mapping)

            # Persist raw predictions next to the TensorBoard logs.
            pred_dir=os.path.join(tensorboard_dir,file+"_test_record")
            to_save_dict={'labels':ytrue_test,'yhat':yhat_test,'yhat_logit':yhatlogit_test}
            torch.save(to_save_dict, pred_dir)

            test_acc=util.calculate(yhat_test,ytrue_test,yhatlogit_test)
            util.print_1(0,'Test', {"loss":test_loss,"acc":test_acc},color=blue)
220
+
221
+
222
if __name__ == '__main__':
    # Fixed seed for the data split so evaluation sees the same
    # train/val/test partition as training did.
    Seed = 0
    test_ratio=0.2
    print("data splitting Random Seed: ", Seed)
    # NOTE(review): "building" loads building_with_index.pkl via
    # get_mbuilding_dataset while "mbuilding" loads mp_building.pkl via
    # get_building_dataset — the loader/path pairing looks swapped; confirm
    # against the dataset module's defaults.
    if args.dataset in ["mnist"]:
        args.data_dir='data/multi_mnist_with_index.pkl'
        train_ds,val_ds,test_ds=dataset.get_mnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ["mnist_sparse"]:
        args.data_dir='data/multi_mnist_sparse.pkl'
        train_ds,val_ds,test_ds=dataset.get_mnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ["building"]:
        args.data_dir='data/building_with_index.pkl'
        train_ds,val_ds,test_ds=dataset.get_mbuilding_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ["mbuilding"]:
        args.data_dir='data/mp_building.pkl'
        train_ds,val_ds,test_ds=dataset.get_building_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ["sbuilding"]:
        args.data_dir='data/single_building.pkl'
        train_ds,val_ds,test_ds=dataset.get_sbuilding_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ["smnist"]:
        args.data_dir='data/single_mnist.pkl'
        train_ds,val_ds,test_ds=dataset.get_smnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['dbp']:
        args.data_dir='data/triple_building_600.pkl'
        train_ds,val_ds,test_ds=dataset.get_dbp_dataset(args.data_dir,Seed,test_ratio=test_ratio)

    # Optionally normalise every graph's coordinates into [-1, 1].
    if args.extent_norm=="T":
        train_ds= dataset.affine_transform_to_range(train_ds,target_range=(-1, 1))
        val_ds= dataset.affine_transform_to_range(val_ds,target_range=(-1, 1))
        test_ds= dataset.affine_transform_to_range(test_ds,target_range=(-1, 1))
    train_loader = torch_geometric.loader.DataLoader(train_ds,batch_size=args.train_batch, shuffle=False,pin_memory=True)
    val_loader = torch_geometric.loader.DataLoader(val_ds, batch_size=args.test_batch, shuffle=False, pin_memory=True)
    test_loader = torch_geometric.loader.DataLoader(test_ds,batch_size=args.test_batch, shuffle=False,pin_memory=True)

    # Fresh random seed for everything after the (fixed-seed) split.
    Seed=random.randint(1, 10000)
    print("Random Seed: ", Seed)
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)
    main(args,train_loader,val_loader,test_loader)
model_new.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from math import pi as PI
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.nn.parallel
9
+ import torch.utils.data
10
+ import torch_geometric.transforms as T
11
+ from torch.nn import ModuleList, Parameter
12
+ from torch_geometric.nn import HANConv, HEATConv, HGTConv, Linear
13
+ from torch_geometric.nn.conv import MessagePassing
14
+ from torch_geometric.nn.dense.linear import Linear
15
+ # from dataset import
16
+ from torch_geometric.nn.inits import glorot, zeros
17
+ from torch_geometric.utils import softmax
18
+ from torch_scatter import scatter
19
+
20
+ from util import get_angle, get_theta, triplets
21
+
22
class Smodel(nn.Module):
    """Geometry-aware message-passing backbone.

    Encodes second-order edge geometry (the two hop lengths and the turning
    angle of each 2-hop path) with an MLP and feeds it through
    ``num_interactions`` stacked :class:`SPNN` blocks.  The forward pass
    returns one node-feature tensor per interaction block (the caller
    concatenates them).
    """

    def __init__(self, h_channel=16,input_featuresize=32,localdepth=2,num_interactions=3,finaldepth=3,share='0',batchnorm="True"):
        # h_channel: hidden width used throughout.
        # input_featuresize: width of the externally supplied node features.
        # localdepth: depth of the geometry-encoding MLPs.
        # num_interactions: number of stacked SPNN blocks.
        # finaldepth: depth of the per-edge-type MLPs inside each SPNN.
        # batchnorm: string flag "True"/"False" (compared as a string below).
        # NOTE(review): `share` is accepted but never used in this class.
        super(Smodel,self).__init__()
        # NOTE(review): this shadows nn.Module's built-in `training` flag;
        # `.train()` / `.eval()` overwrite it anyway, so it is redundant.
        self.training=True
        self.h_channel = h_channel
        self.input_featuresize=input_featuresize
        self.localdepth = localdepth
        self.num_interactions=num_interactions
        self.finaldepth=finaldepth
        self.batchnorm = batchnorm
        self.activation=nn.ReLU()
        # Learnable gate per edge-type pair (inside/apart x inside/apart);
        # consumed by SPNN.forward.
        self.att = Parameter(torch.ones(4),requires_grad=True)

        # One scalar slot per geometric feature: (dist_ij, dist_jk, angle_ijk).
        num_gaussians=(1,1,1)
        self.mlp_geo = ModuleList()
        for i in range(self.localdepth):
            if i == 0:
                self.mlp_geo.append(Linear(sum(num_gaussians), h_channel))
            else:
                self.mlp_geo.append(Linear(h_channel, h_channel))
            if self.batchnorm == "True":
                self.mlp_geo.append(nn.BatchNorm1d(h_channel))
            self.mlp_geo.append(self.activation)

        # Fallback encoder for the edge_rep=False ablation: consumes raw 2-D
        # endpoint coordinates (2 nodes x 2 coords = 4 inputs) instead of
        # rotation/translation-invariant features.
        self.mlp_geo_backup = ModuleList()
        for i in range(self.localdepth):
            if i == 0:
                self.mlp_geo_backup.append(Linear(4, h_channel))
            else:
                self.mlp_geo_backup.append(Linear(h_channel, h_channel))
            if self.batchnorm == "True":
                self.mlp_geo_backup.append(nn.BatchNorm1d(h_channel))
            self.mlp_geo_backup.append(self.activation)
        # NOTE(review): defined but never used in single_forward/forward.
        self.translinear=Linear(input_featuresize+1, self.h_channel)
        self.interactions= ModuleList()
        for i in range(self.num_interactions):
            block = SPNN(
                in_ch=self.input_featuresize,
                hidden_channels=self.h_channel,
                activation=self.activation,
                finaldepth=self.finaldepth,
                batchnorm=self.batchnorm,
                num_input_geofeature=self.h_channel
            )
            self.interactions.append(block)
        self.reset_parameters()

    def reset_parameters(self):
        # Xavier-init the geometry MLP linears and delegate to each SPNN block.
        for lin in self.mlp_geo:
            if isinstance(lin, Linear):
                torch.nn.init.xavier_uniform_(lin.weight)
                lin.bias.data.fill_(0)
        for i in (self.interactions):
            i.reset_parameters()

    def single_forward(self, input_feature,coords,edge_index,edge_index_2rd, edx_jk, edx_ij,batch,num_edge_inside,edge_rep):
        # edge_index: [inside_edges, apart_edges] pair of (2, E) tensors.
        # edge_index_2rd: (i, j, k) node triplets, one per 2-hop path.
        # edx_ij / edx_jk: indices of the two hops in the combined edge list
        # (positions < num_edge_inside are "inside" edges, the rest "apart").
        if edge_rep:
            # Invariant features per triplet: both hop lengths plus the
            # turning angle between hop (i->j) and hop (j->k).
            i, j, k = edge_index_2rd
            edge_index1,edge_index2= edge_index
            edge_index_all=torch.cat([edge_index1,edge_index2],1)
            distance_ij=(coords[j] - coords[i]).norm(p=2, dim=1)
            distance_jk=(coords[j] - coords[k]).norm(p=2, dim=1)
            theta_ijk = get_angle(coords[j] - coords[i], coords[k] - coords[j])
            geo_encoding_1st=distance_ij[:,None]
            geo_encoding=torch.cat([geo_encoding_1st,distance_jk[:,None],theta_ijk[:,None]],dim=-1)
        else:
            # Ablation path: raw (non-invariant) endpoint coordinates.
            coords_j = coords[edge_index[0]]
            coords_i = coords[edge_index[1]]
            geo_encoding=torch.cat([coords_j,coords_i],dim=-1)
        if edge_rep:
            for lin in self.mlp_geo:
                geo_encoding=lin(geo_encoding)
        else:
            for lin in self.mlp_geo_backup:
                geo_encoding=lin(geo_encoding)
            # NOTE(review): the encoded features are immediately zeroed, so the
            # edge_rep=False path feeds no geometry at all — presumably a
            # deliberate ablation; confirm before reusing this branch.
            geo_encoding=torch.zeros_like(geo_encoding,device=geo_encoding.device,dtype=geo_encoding.dtype)
        node_feature= input_feature
        node_feature_list=[]
        # Each interaction refines node features; all intermediate states are
        # returned so the caller can concatenate them.
        for interaction in self.interactions:
            node_feature = interaction(node_feature,geo_encoding,edge_index_2rd,edx_jk,edx_ij,num_edge_inside,self.att)
            node_feature_list.append(node_feature)
        return node_feature_list

    def forward(self, input_feature, coords,edge_index,edge_index_2rd, edx_jk, edx_ij,batch,num_edge_inside,edge_rep):
        """Return the list of node-feature tensors, one per interaction block."""
        output=self.single_forward(input_feature,coords,edge_index,edge_index_2rd, edx_jk, edx_ij,batch,num_edge_inside,edge_rep)
        return output
106
+
107
class SPNN(torch.nn.Module):
    """One interaction block of :class:`Smodel`.

    Node triplets (i, j, k) are partitioned by the edge types of their two
    hops (inside/apart), each partition is transformed by its own MLP and
    scaled by the shared learnable gate ``att``; the results are then
    sum-aggregated onto the target node i.
    """

    def __init__(
        self,
        in_ch,
        hidden_channels,
        activation=torch.nn.ReLU(),
        finaldepth=3,
        batchnorm="True",
        num_input_geofeature=13
    ):
        # NOTE(review): `in_ch` is accepted but not referenced below; the
        # first Linear assumes node features already have `hidden_channels`.
        super(SPNN, self).__init__()
        self.activation = activation
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.num_input_geofeature=num_input_geofeature

        # Four MLPs, one per (edge-type of hop ij) x (edge-type of hop jk).
        self.WMLP_list = ModuleList()
        for _ in range(4):
            WMLP = ModuleList()
            for i in range(self.finaldepth + 1):
                if i == 0:
                    # Input: features of i, j, k plus the geometry encoding.
                    WMLP.append(Linear(hidden_channels*3+num_input_geofeature, hidden_channels))
                else:
                    WMLP.append(Linear(hidden_channels, hidden_channels))
                if self.batchnorm == "True":
                    WMLP.append(nn.BatchNorm1d(hidden_channels))
                WMLP.append(self.activation)
            self.WMLP_list.append(WMLP)
        self.reset_parameters()

    def reset_parameters(self):
        # Xavier-init every Linear in all four MLPs.
        for mlp in self.WMLP_list:
            for lin in mlp:
                if isinstance(lin, Linear):
                    torch.nn.init.xavier_uniform_(lin.weight)
                    lin.bias.data.fill_(0)

    def forward(self, node_feature,geo_encoding,edge_index_2rd,edx_jk,edx_ij,num_edge_inside,att):
        # edx_ij / edx_jk index the two hops of each triplet in the combined
        # edge list; indices < num_edge_inside are "inside" edges.
        i,j,k = edge_index_2rd
        if node_feature is None:
            concatenated_vector = geo_encoding
        else:
            node_attr_0st = node_feature[i]
            node_attr_1st = node_feature[j]
            node_attr_2 = node_feature[k]
            concatenated_vector = torch.cat(
                [
                    node_attr_0st,
                    node_attr_1st,node_attr_2,
                    geo_encoding,
                ],
                dim=-1,
            )
        x_i = concatenated_vector

        # Partition triplets by the edge-type pair of their two hops.
        edge1_edge1_mask = (edx_ij < num_edge_inside) & (edx_jk < num_edge_inside)
        edge1_edge2_mask = (edx_ij < num_edge_inside) & (edx_jk >= num_edge_inside)
        edge2_edge1_mask = (edx_ij >= num_edge_inside) & (edx_jk < num_edge_inside)
        edge2_edge2_mask = (edx_ij >= num_edge_inside) & (edx_jk >= num_edge_inside)
        masks=[edge1_edge1_mask,edge1_edge2_mask,edge2_edge1_mask,edge2_edge2_mask]

        # Run each partition through its dedicated MLP, gate it with the
        # learnable per-partition weight, and write the result back in place.
        x_output=torch.zeros(x_i.shape[0],self.WMLP_list[0][0].weight.shape[0],device=x_i.device)
        for index in range(4):
            WMLP=self.WMLP_list[index]
            x=x_i[masks[index]]
            for lin in WMLP:
                x=lin(x)
            x = F.leaky_relu(x)*att[index]
            x_output[masks[index]]+=x

        # Sum the messages of all triplets sharing the same target node i.
        out_feature = scatter(x_output, i, dim=0, reduce='add')
        return out_feature
178
+
179
class HGT(torch.nn.Module):
    """Heterogeneous Graph Transformer baseline.

    Operates on the single node type "vertices" with the two edge relations
    "inside" and "apart"; returns per-vertex embeddings of size out_channels.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # Lazy input projection per node type (in_channels=-1 is inferred on
        # the first forward call).
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in ["vertices"]:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        # Stack of HGT convolutions sharing one metadata tuple
        # (node types, edge types).
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, (['vertices'],[('vertices', 'inside', 'vertices'), ('vertices', 'apart', 'vertices')]),
                num_heads, group='sum')
            self.convs.append(conv)

        # Output head applied to the final "vertices" embeddings.
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        # Project raw features of every node type, then run the conv stack.
        for node_type, x in x_dict.items():
            x_dict[node_type]=self.lin_dict[node_type](x).relu_()

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict['vertices'])
202
class HAN(torch.nn.Module):
    """Heterogeneous Attention Network baseline.

    Structurally identical to :class:`HGT` but uses HANConv layers; operates
    on the single node type "vertices" with "inside"/"apart" edge relations.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # Lazy input projection per node type (in_channels inferred lazily).
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in ["vertices"]:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        # Stack of HAN convolutions sharing one metadata tuple.
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HANConv(hidden_channels, hidden_channels, (['vertices'],[('vertices', 'inside', 'vertices'), ('vertices', 'apart', 'vertices')]),
                num_heads)
            self.convs.append(conv)

        # Output head applied to the final "vertices" embeddings.
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        # Project raw features of every node type, then run the conv stack.
        for node_type, x in x_dict.items():
            x_dict[node_type]=self.lin_dict[node_type](x).relu_()

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict['vertices'])
train_new.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+ import json
4
+ import os
5
+ import random
6
+ import time
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import pandas as pd
11
+ import scipy
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torch.optim as optim
15
+ from matplotlib import cm
16
+ from sklearn.metrics import (auc, explained_variance_score, f1_score,
17
+ mean_absolute_error, mean_squared_error,
18
+ precision_score, r2_score, recall_score,
19
+ roc_auc_score, roc_curve)
20
+ from torch.nn.functional import softmax
21
+ from torch_geometric.utils import subgraph
22
+
23
+ torch.autograd.set_detect_anomaly(True)
24
+ import math
25
+ import pickle
26
+ import time
27
+ from datetime import date, datetime, timedelta
28
+
29
+ import torch.nn as nn
30
+ import torch_geometric
31
+ import torchvision.datasets
32
+ import torchvision.models
33
+ import torchvision.transforms as transforms
34
+ from torch.utils.tensorboard import SummaryWriter
35
+ from torch_geometric.nn import GIN, MLP, GATConv
36
+ from torch_geometric.nn.pool import global_add_pool, global_mean_pool
37
+ from torch_geometric.utils import add_self_loops
38
+
39
+ import dataset
40
+ import model_new
41
+ import util
42
+ from dataset import label_mapping, reverse_label_mapping
43
+ from model_new import Smodel
44
+
45
+ blue = lambda x: '\033[94m' + x + '\033[0m'
46
+ red = lambda x: '\033[31m' + x + '\033[0m'
47
+ green = lambda x: '\033[32m' + x + '\033[0m'
48
+ yellow = lambda x: '\033[33m' + x + '\033[0m'
49
+ greenline = lambda x: '\033[42m' + x + '\033[0m'
50
+ yellowline = lambda x: '\033[43m' + x + '\033[0m'
51
+
52
def get_args():
    """Build and parse the command-line configuration.

    String flags passed as "True"/"False" are converted to real booleans,
    and ``save_dir`` is derived from the dataset and model names.
    """
    p = argparse.ArgumentParser()

    # Model selection and batching.
    p.add_argument('--model', default="our", type=str)
    p.add_argument('--train_batch', default=64, type=int)
    p.add_argument('--test_batch', default=128, type=int)
    p.add_argument('--share', type=str, default="0")
    p.add_argument('--edge_rep', type=str, default="True")
    p.add_argument('--batchnorm', type=str, default="True")
    p.add_argument('--extent_norm', type=str, default="T")
    p.add_argument('--spanning_tree', type=str, default="T")

    # Architecture hyper-parameters.
    p.add_argument('--loss_coef', default=0.1, type=float)
    p.add_argument('--h_ch', default=512, type=int)
    p.add_argument('--localdepth', type=int, default=1)
    p.add_argument('--num_interactions', type=int, default=4)
    p.add_argument('--finaldepth', type=int, default=4)
    p.add_argument('--classifier_depth', type=int, default=4)
    p.add_argument('--dropout', type=float, default=0.0)

    # Training / logging schedule.
    p.add_argument('--dataset', type=str, default='mnist')
    p.add_argument('--log', type=str, default="True")
    p.add_argument('--test_per_round', type=int, default=10)
    p.add_argument('--patience', type=int, default=30)  # scheduler patience
    p.add_argument('--nepoch', type=int, default=301)
    p.add_argument('--lr', type=float, default=1e-4)
    p.add_argument('--manualSeed', type=str, default="False")
    p.add_argument('--man_seed', type=int, default=12345)

    args = p.parse_args()
    # Convert the "True"/"False" string flags into real booleans.
    args.log = args.log == "True"
    args.edge_rep = args.edge_rep == "True"
    args.batchnorm = args.batchnorm == "True"
    args.save_dir = os.path.join('./save/', args.dataset, args.model)
    args.manualSeed = args.manualSeed == "True"
    return args
86
+
87
args = get_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion=nn.CrossEntropyLoss()
# Number of target classes (x_out) and dataset pickle path per dataset name.
if args.dataset in ["mnist"]:
    x_out=90
    args.data_dir='data/multi_mnist_with_index.pkl'
elif args.dataset in ["mnist_sparse"]:
    x_out=90
    args.data_dir='data/multi_mnist_sparse.pkl'
elif args.dataset in ["building"]:
    x_out=100
    args.data_dir='data/building_with_index.pkl'
elif args.dataset in ["mbuilding"]:
    x_out=100
    args.data_dir='data/mp_building.pkl'
elif args.dataset in ["sbuilding"]:
    x_out=10
    args.data_dir='data/single_building.pkl'
elif args.dataset in ["smnist"]:
    x_out=10
    args.data_dir='data/single_mnist.pkl'
elif args.dataset in ["dbp"]:
    x_out=2
    args.data_dir='data/triple_building_600.pkl'


# Backbone + MLP classifier head. "our" concatenates the outputs of all
# interaction blocks, so its head input width is h_ch * num_interactions.
if args.model=="our":
    model=Smodel(h_channel=args.h_ch,input_featuresize=args.h_ch,\
        localdepth=args.localdepth,num_interactions=args.num_interactions,finaldepth=args.finaldepth,share=args.share,batchnorm=args.batchnorm)
    mlpmodel=MLP(in_channels=args.h_ch*args.num_interactions, hidden_channels=args.h_ch,out_channels=x_out, num_layers=args.classifier_depth,dropout=args.dropout)

elif args.model=="HGT":
    model=model_new.HGT(hidden_channels=args.h_ch, out_channels=args.h_ch, num_heads=2, num_layers=args.num_interactions)
    mlpmodel=MLP(in_channels=args.h_ch, hidden_channels=args.h_ch,out_channels=x_out, num_layers=args.classifier_depth,dropout=args.dropout)
elif args.model=="HAN":
    model=model_new.HAN(hidden_channels=args.h_ch, out_channels=args.h_ch, num_heads=2, num_layers=args.num_interactions)
    mlpmodel=MLP(in_channels=args.h_ch, hidden_channels=args.h_ch,out_channels=x_out, num_layers=args.classifier_depth,dropout=args.dropout)

model.to(device), mlpmodel.to(device)
# Single optimizer over both the backbone and the classifier head.
opt_list=list(model.parameters())+list(mlpmodel.parameters())

optimizer = torch.optim.Adam( opt_list, lr=args.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=args.patience, min_lr=1e-8)
130
+
131
def contrastive_loss(embeddings,labels,margin):
    """Margin-based contrastive loss over all label pairs in the batch.

    Same-label (positive) pairs should be closer than different-label
    (negative) pairs by at least ``margin``.  The larger of the two pair
    sets is randomly subsampled so positives and negatives contribute in
    equal number.

    Args:
        embeddings: (B, D) tensor of graph embeddings.
        labels: (B,) tensor of integer class labels.
        margin: scalar margin between positive and negative distances.

    Returns:
        Scalar loss tensor (zero if the batch is all-positive/all-negative).
    """
    positive_mask = labels.view(-1, 1) == labels.view(1, -1)
    negative_mask = ~positive_mask

    # Count off-diagonal positive pairs (subtract the B self-pairs on the
    # diagonal) and negative pairs.
    num_positive_pairs = positive_mask.sum() - labels.shape[0]
    num_negative_pairs = negative_mask.sum()

    # Nothing to contrast — return a zero placeholder.  Created on the same
    # device as the embeddings so `loss += c_loss * coef` never mixes a CPU
    # tensor with a CUDA loss (the original returned a CPU tensor).
    if num_negative_pairs==0 or num_positive_pairs== 0:
        print("all pos or neg")
        return torch.tensor(0, dtype=torch.float, device=embeddings.device)
    # Pairwise Euclidean distances, normalised by sqrt(dim).
    distances = torch.cdist(embeddings, embeddings)/np.sqrt(embeddings.shape[1])

    if num_positive_pairs>num_negative_pairs:
        # Subsample positives down to the number of negatives.
        # BUGFIX: clear the diagonal BEFORE sampling; previously self-pairs
        # (distance 0) could be drawn as "positives" and bias the loss down.
        positive_mask.fill_diagonal_(False)
        positive_indices = torch.nonzero(positive_mask)
        random_positive_indices = torch.randperm(len(positive_indices))[:num_negative_pairs]
        selected_positive_indices = positive_indices[random_positive_indices]

        # Select corresponding negative pairs (diagonal already False here,
        # since self-pairs are positive by construction).
        negative_mask.fill_diagonal_(False)
        negative_distances = distances[negative_mask].view(-1, 1)
        positive_distances = distances[selected_positive_indices[:,0],selected_positive_indices[:,1]].view(-1, 1)
    else: # case for most datasets
        # Subsample negatives down to the number of positives.
        negative_indices = torch.nonzero(negative_mask)
        random_negative_indices = torch.randperm(len(negative_indices))[:num_positive_pairs]
        selected_negative_indices = negative_indices[random_negative_indices]

        # Use all off-diagonal positive pairs.
        positive_mask.fill_diagonal_(False)
        positive_distances = distances[positive_mask].view(-1, 1)
        negative_distances = distances[selected_negative_indices[:,0],selected_negative_indices[:,1]].view(-1, 1)

    # Hinge on the pairwise gap: positives should be at least `margin`
    # closer than negatives.
    loss = (positive_distances - negative_distances + margin).clamp(min=0).mean()
    return loss
171
+
172
def forward_HGT(data,model,mlpmodel):
    """One forward pass for the HGT/HAN baselines.

    Returns (total loss, weighted contrastive loss, logits, labels).
    Relies on module-level globals: args, device, criterion, optimizer.
    """
    data = data.to(device)
    x,batch=data.pos, data['vertices'].batch
    # Vertex coordinates double as the input node features.
    data["vertices"]['x']=data.pos
    label=data.y.long().view(-1)

    optimizer.zero_grad()

    output=model(data.x_dict, data.edge_index_dict)
    # NOTE(review): both branches are identical; the "dbp" special case looks
    # like a leftover from an earlier pooling variant.
    if args.dataset in ["dbp"]:
        graph_embeddings=global_add_pool(output,batch)
    else:
        graph_embeddings=global_add_pool(output,batch)
    # Guard against exploding embeddings before the contrastive term.
    graph_embeddings.clamp_(max=1e6)
    c_loss=contrastive_loss(graph_embeddings,label,margin=1)
    output=mlpmodel(graph_embeddings)
    # log_probs = F.log_softmax(output, dim=1)

    # Cross-entropy on raw logits plus the weighted contrastive term.
    loss = criterion(output, label)
    loss+=c_loss*args.loss_coef
    return loss,c_loss*args.loss_coef,output,label
193
+
194
def forward(data,model,mlpmodel):
    """One forward pass of the "our" (Smodel) pipeline on a hetero batch.

    Returns (total loss, weighted contrastive loss, logits, labels).
    Relies on module-level globals: args, device, criterion, optimizer, util.
    """
    data = data.to(device)
    edge_index1=data['vertices', 'inside', 'vertices']['edge_index']
    edge_index2=data['vertices', 'apart', 'vertices']['edge_index']
    combined_edge_index=torch.cat([data['vertices', 'inside', 'vertices']['edge_index'],data['vertices', 'apart', 'vertices']['edge_index']],1)
    num_edge_inside=edge_index1.shape[1]

    if args.spanning_tree == 'T':
        # Sparsify the "apart" edges: draw random weights, build a spanning
        # tree over the combined graph, and keep only apart-edges that made
        # it into the tree (all inside-edges are always kept).
        edge_weight=torch.rand(combined_edge_index.shape[1]) + 1
        undirected_spanning_edge = util.build_spanning_tree_edge(combined_edge_index, edge_weight,num_nodes=data.pos.shape[0])

        edge_set_1 = set(map(tuple, edge_index2.t().tolist()))
        edge_set_2 = set(map(tuple, undirected_spanning_edge.t().tolist()))

        common_edges = edge_set_1.intersection(edge_set_2)
        # NOTE(review): if no apart-edge lands in the tree, `common_edges` is
        # empty, this tensor becomes 1-D, and the cat below would fail.
        common_edges_tensor = torch.tensor(list(common_edges), dtype=torch.long).t().to(device)
        spanning_edge=torch.cat([edge_index1,common_edges_tensor],1)
        combined_edge_index=spanning_edge
    x,batch=data.pos, data['vertices'].batch
    label=data.y.long().view(-1)

    # Build all 2-hop triplets over the (possibly sparsified) edge list.
    num_nodes=x.shape[0]
    edge_index_2rd, num_triplets_real, edx_jk, edx_ij = util.triplets(combined_edge_index, num_nodes)
    optimizer.zero_grad()
    # Node features start as zeros: all signal comes from the geometry.
    input_feature=torch.zeros([x.shape[0],args.h_ch],device=device)
    output=model(input_feature,x,[edge_index1,edge_index2], edge_index_2rd,edx_jk, edx_ij,batch,num_edge_inside,args.edge_rep)
    # Concatenate the per-interaction node features for the classifier head.
    output=torch.cat(output,dim=1)
    # NOTE(review): both branches identical; "dbp" special case is a leftover.
    if args.dataset in ["dbp"]:
        graph_embeddings=global_add_pool(output,batch)
    else:
        graph_embeddings=global_add_pool(output,batch)
    # Guard against exploding embeddings before the contrastive term.
    graph_embeddings.clamp_(max=1e6)
    c_loss=contrastive_loss(graph_embeddings,label,margin=1)
    output=mlpmodel(graph_embeddings)

    # Cross-entropy on raw logits plus the weighted contrastive term.
    loss = criterion(output, label)
    loss+=c_loss*args.loss_coef
    return loss,c_loss*args.loss_coef,output,label
232
def train(train_Loader,model,mlpmodel ):
    """Run one training epoch.

    Returns (mean batch loss, mean contrastive loss, predictions, ground
    truth, logits).  Relies on globals: args, optimizer.
    """
    epochloss=0
    epochcloss=0
    y_hat, y_true,y_hat_logit = [], [], []
    optimizer.zero_grad()
    model.train()
    mlpmodel.train()
    for i,data in enumerate(train_Loader):
        # Dispatch to the forward pass matching the selected model.
        if args.model=="our":
            loss,c_loss,output,label =forward(data,model,mlpmodel)
        elif args.model in ["HGT","HAN"]:
            loss,c_loss,output,label =forward_HGT(data,model,mlpmodel)

        loss.backward()
        optimizer.step()
        epochloss+=loss.detach().cpu()
        epochcloss+=c_loss.detach().cpu()

        # Top-1 class prediction per graph.
        _, pred = output.topk(1, dim=1, largest=True, sorted=True)
        pred,label,output=pred.cpu(),label.cpu(),output.cpu()
        y_hat += list(pred.detach().numpy().reshape(-1))
        y_true += list(label.detach().numpy().reshape(-1))
        y_hat_logit+=list(output.detach().numpy())
    # Means are per-batch (not per-sample).
    return epochloss.item()/len(train_Loader),epochcloss.item()/len(train_Loader),y_hat, y_true,y_hat_logit
256
+
257
def test(loader,model,mlpmodel ):
    """Evaluate on a loader without gradient tracking.

    Returns (sample-weighted mean loss, predictions, ground truth, logits).
    Relies on globals: args.
    """
    y_hat, y_true,y_hat_logit = [], [], []
    loss_total, pred_num = 0, 0
    model.eval()
    mlpmodel.eval()
    with torch.no_grad():
        for data in loader:
            # Dispatch to the forward pass matching the selected model.
            if args.model=="our":
                loss,c_loss,output,label =forward(data,model,mlpmodel)
            elif args.model in ["HGT","HAN"]:
                loss,c_loss,output,label =forward_HGT(data,model,mlpmodel)

            # Top-1 class prediction per graph.
            _, pred = output.topk(1, dim=1, largest=True, sorted=True)
            pred,label,output=pred.cpu(),label.cpu(),output.cpu()
            y_hat += list(pred.detach().numpy().reshape(-1))
            y_true += list(label.detach().numpy().reshape(-1))
            y_hat_logit+=list(output.detach().numpy())

            # Weight each batch loss by its sample count so the overall mean
            # is per-sample even with a ragged final batch.
            pred_num += len(label.reshape(-1, 1))
            loss_total += loss.detach() * len(label.reshape(-1, 1))
    return loss_total/pred_num,y_hat, y_true, y_hat_logit
278
def main(args,train_Loader,val_Loader,test_Loader):
    """Full training loop.

    Trains for args.nepoch epochs, evaluates every args.test_per_round
    epochs, keeps the model with the best validation accuracy, re-evaluates
    it at the end, and writes logs/checkpoint/info files named by a
    timestamp suffix.  Relies on module-level globals: model, mlpmodel,
    scheduler, tensorboard_dir, info_dir, model_dir, Seed.
    """
    best_val_trigger = -1
    # Sentinel larger than any real lr so the first comparison prints.
    old_lr=1e3
    # Timestamp used as the run identifier for logs/models/info files.
    suffix="{}{}-{}:{}:{}".format(datetime.now().strftime("%h"),
        datetime.now().strftime("%d"),
        datetime.now().strftime("%H"),
        datetime.now().strftime("%M"),
        datetime.now().strftime("%S"))
    if args.log: writer = SummaryWriter(os.path.join(tensorboard_dir,suffix))

    for epoch in range(args.nepoch):
        train_loss,train_closs,y_hat, y_true,y_hat_logit=train(train_Loader,model,mlpmodel )

        train_acc=util.calculate(y_hat,y_true,y_hat_logit)
        # `writer` only exists when args.log is set; the bare except keeps
        # TensorBoard logging strictly optional.
        try:util.record({"loss":train_loss,"closs":train_closs,"acc":train_acc},epoch,writer,"Train")
        except: pass
        util.print_1(epoch,'Train',{"loss":train_loss,"closs":train_closs,"acc":train_acc})
        if epoch % args.test_per_round == 0:
            val_loss, yhat_val, ytrue_val, yhatlogit_val = test(val_Loader,model,mlpmodel )
            test_loss, yhat_test, ytrue_test, yhatlogit_test = test(test_Loader,model,mlpmodel )
            val_acc=util.calculate(yhat_val,ytrue_val,yhatlogit_val)
            try:util.record({"loss":val_loss,"acc":val_acc},epoch,writer,"Val")
            except: pass
            util.print_1(epoch,'Val',{"loss":val_loss,"acc":val_acc},color=blue)
            test_acc=util.calculate(yhat_test,ytrue_test,yhatlogit_test)
            try:util.record({"loss":test_loss,"acc":test_acc},epoch,writer,"Test")
            except: pass
            util.print_1(epoch,'Test',{"loss":test_loss,"acc":test_acc},color=blue)
            # Model selection: snapshot the weights with the best val acc.
            val_trigger=val_acc
            if val_trigger > best_val_trigger:
                best_val_trigger = val_trigger
                best_model = copy.deepcopy(model)
                best_mlpmodel=copy.deepcopy(mlpmodel)
                best_info=[epoch,val_trigger]
            """
            update lr when epoch≥30
            """
            # Plateau scheduler driven by validation accuracy; print on change.
            if epoch >= 30:
                lr = scheduler.optimizer.param_groups[0]['lr']
                if old_lr!=lr:
                    print(red('lr'), epoch, (lr), sep=', ')
                    old_lr=lr
                scheduler.step(val_trigger)
    """
    use best model to get best model result
    """
    val_loss, yhat_val, ytrue_val, yhat_logit_val = test(val_Loader,best_model,best_mlpmodel)
    test_loss, yhat_test, ytrue_test, yhat_logit_test= test(test_Loader,best_model,best_mlpmodel)

    val_acc=util.calculate(yhat_val,ytrue_val,yhat_logit_val)
    util.print_1(best_info[0],'BestVal',{"loss":val_loss,"acc":val_acc},color=blue)
    test_acc=util.calculate(yhat_test,ytrue_test,yhat_logit_test)
    util.print_1(best_info[0],'BestTest',{"loss":test_loss,"acc":test_acc},color=blue)
    if args.model=="our":print(best_model.att)

    """
    save training info and best result
    """
    # Human-readable run summary (seed, metrics, full arg list).
    result_file=os.path.join(info_dir, suffix)
    with open(result_file, 'w') as f:
        print("Random Seed: ", Seed,file=f)
        print(f"acc val : {val_acc:.3f}, Test : {test_acc:.3f}", file=f)
        print(f"Best info: {best_info}", file=f)
        for i in [[a,getattr(args, a)] for a in args.__dict__]:
            print(i,sep='\n',file=f)
    # Checkpoint: best weights plus the test-set predictions they produced.
    to_save_dict={'model':best_model.state_dict(),'mlpmodel':best_mlpmodel.state_dict(),'args':args,'labels':ytrue_test,'yhat':yhat_test,'yhat_logit':yhat_logit_test}
    torch.save(to_save_dict, os.path.join(model_dir,suffix+'.pth') )
    print("done")
346
+
347
if __name__ == '__main__':
    """
    build dir
    """
    # Create save/log/model/info directories for this dataset+model combo.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir,exist_ok=True)
    tensorboard_dir=os.path.join(args.save_dir,'log')
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir,exist_ok=True)
    model_dir=os.path.join(args.save_dir,'model')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir,exist_ok=True)
    info_dir=os.path.join(args.save_dir,'info')
    if not os.path.exists(info_dir):
        os.makedirs(info_dir,exist_ok=True)

    # Fixed split seed so every run shares the same train/val/test split.
    Seed = 0
    test_ratio=0.2
    print("data splitting Random Seed: ", Seed)
    if args.dataset in ['mnist',"mnist_sparse"]:
        train_ds,val_ds,test_ds=dataset.get_mnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['building']:
        train_ds,val_ds,test_ds=dataset.get_building_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['mbuilding']:
        train_ds,val_ds,test_ds=dataset.get_mbuilding_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['sbuilding']:
        train_ds,val_ds,test_ds=dataset.get_sbuilding_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['smnist']:
        train_ds,val_ds,test_ds=dataset.get_smnist_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    elif args.dataset in ['dbp']:
        train_ds,val_ds,test_ds=dataset.get_dbp_dataset(args.data_dir,Seed,test_ratio=test_ratio)
    # Optionally normalise all coordinates into [-1, 1].
    if args.extent_norm=="T":
        train_ds= dataset.affine_transform_to_range(train_ds,target_range=(-1, 1))
        val_ds= dataset.affine_transform_to_range(val_ds,target_range=(-1, 1))
        test_ds= dataset.affine_transform_to_range(test_ds,target_range=(-1, 1))

    train_loader = torch_geometric.loader.DataLoader(train_ds,batch_size=args.train_batch, shuffle=False,pin_memory=True,drop_last=True)
    val_loader = torch_geometric.loader.DataLoader(val_ds, batch_size=args.test_batch, shuffle=False, pin_memory=True)
    test_loader = torch_geometric.loader.DataLoader(test_ds,batch_size=args.test_batch, shuffle=False,pin_memory=True)
    """
    set model seed
    """
    Seed = args.man_seed if args.manualSeed else random.randint(1, 10000)
    # NOTE(review): this hard-coded override makes --manualSeed/--man_seed
    # (and the line above) dead code; remove it to honour the CLI flags.
    Seed=3407
    print("Random Seed: ", Seed)
    print(args)
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)
    main(args,train_loader,val_loader,test_loader)
397
+
util.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from matplotlib import cm
6
+ import matplotlib.pyplot as plt
7
+ import scipy
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+
11
+ from sklearn.metrics import explained_variance_score,mean_squared_error,mean_absolute_error,r2_score,precision_score,recall_score,f1_score,roc_auc_score,roc_curve, auc,confusion_matrix
12
+ from sklearn.feature_selection import r_regression
13
+
14
+ from torch_sparse import SparseTensor
15
+ from scipy.sparse import csr_matrix
16
+ from scipy.sparse.csgraph import minimum_spanning_tree
17
+ from math import pi as PI
18
+
19
def scipy_spanning_tree(edge_index, edge_weight, num_nodes):
    """Compute a minimum spanning tree of a weighted graph with SciPy.

    Args:
        edge_index: (2, E) tensor of (source, target) node indices.
        edge_weight: (E,) tensor of edge weights.
        num_nodes: total number of nodes in the graph.

    Returns:
        (2, T) numpy array of the MST edges (row indices, col indices),
        in the row-major order produced by ``Tcsr.nonzero()``.
    """
    # Move to CPU: scipy operates on numpy-compatible buffers only.
    src, dst = edge_index.cpu()
    weights = edge_weight.cpu()
    adjacency = csr_matrix((weights, (src, dst)), shape=(num_nodes, num_nodes))
    mst = minimum_spanning_tree(adjacency)
    mst_src, mst_dst = mst.nonzero()
    return np.stack([mst_src, mst_dst], 0)
27
+
28
def build_spanning_tree_edge(edge_index, edge_weight, num_nodes):
    """Return the minimum-spanning-tree edges as an undirected edge set.

    Runs :func:`scipy_spanning_tree` and mirrors every tree edge so each
    one appears in both directions.

    Returns:
        (2, 2*T) LongTensor on the same device as ``edge_index``.
    """
    tree = scipy_spanning_tree(edge_index, edge_weight, num_nodes)
    tree = torch.tensor(tree, dtype=torch.long, device=edge_index.device)
    # Append the reversed copy so the tree is usable as an undirected graph.
    mirrored = torch.stack([tree[1], tree[0]])
    return torch.cat([tree, mirrored], 1)
34
+
35
+
36
+
37
+
38
def record(values, epoch, writer, phase="Train"):
    """Write every metric in ``values`` to a tensorboard writer.

    Each metric is logged under the tag ``<metric-name>/<phase>`` at
    step ``epoch``.
    """
    for name, metric in values.items():
        writer.add_scalar(f"{name}/{phase}", metric, epoch)
42
def calculate(y_hat, y_true, y_hat_logit):
    """Return classification accuracy of predictions against labels.

    Args:
        y_hat: sequence of predicted class labels.
        y_true: sequence of ground-truth class labels.
        y_hat_logit: raw prediction scores; accepted for interface
            compatibility but currently unused (the extra metrics that
            consumed it — recall/precision/F1/ROC — are disabled).

    Returns:
        Fraction of positions where ``y_hat`` equals ``y_true``.
    """
    matches = (np.array(y_hat) == np.array(y_true)).sum()
    return matches / len(y_true)
53
+
54
+
55
def print_1(epoch, phase, values, color=None):
    """Print a one-line epoch summary, optionally through a ``color`` callable.

    ``color`` (e.g. a termcolor partial) wraps the whole line when given;
    otherwise the line is printed as-is.
    """
    metrics = " ".join(f"{key}={value:.3f}" for key, value in values.items())
    line = f"epoch[{epoch:d}] {phase}" + metrics
    if color is None:
        print(line)
    else:
        print(color(line))
61
+
62
def get_angle(v1, v2):
    """Unsigned angle in [0, pi] between corresponding row vectors of v1 and v2.

    2-D inputs are zero-padded to 3-D so the cross product is defined;
    the angle is atan2(|v1 x v2|, v1 . v2), which is numerically stabler
    than acos of the normalized dot product.
    """
    if v1.shape[1] == 2:
        v1 = F.pad(v1, (0, 1), value=0)
    if v2.shape[1] == 2:
        v2 = F.pad(v2, (0, 1), value=0)
    cross_norm = torch.cross(v1, v2, dim=1).norm(p=2, dim=1)
    dot = (v1 * v2).sum(dim=1)
    return torch.atan2(cross_norm, dot)


def get_theta(v1, v2):
    """Signed angle from v1 to v2 (right-hand rule: counter-clockwise is +).

    The sign is taken from the z component of v1 x v2; a zero cross
    product (parallel or zero vectors) is treated as negative.
    """
    unsigned = get_angle(v1, v2)
    if v1.shape[1] == 2:
        v1 = F.pad(v1, (0, 1), value=0)
    if v2.shape[1] == 2:
        v2 = F.pad(v2, (0, 1), value=0)
    z = torch.cross(v1, v2, dim=1)[..., 2]
    sign = torch.sign(z)
    sign[sign == 0] = -1
    return unsigned * sign
79
+
80
def triplets(edge_index, num_nodes):
    """Enumerate directed 2-hop triplets (i <- j <- k) of a graph.

    For every pair of incident edges (k -> j, j -> i) builds a triplet,
    then filters out degenerate combinations (go-back paths and
    self-loop artifacts) via the three masks below.

    Args:
        edge_index: (2, E) tensor of (row, col) = (source, target) indices.
        num_nodes: number of nodes in the graph.

    Returns:
        A 4-tuple:
        - (3, T) stacked (idx_i, idx_j, idx_k) node indices per triplet,
        - per-edge count of surviving triplets (long tensor),
        - edx_1st: edge id of the first hop (k -> j) per triplet,
        - edx_2nd: edge id of the second hop (j -> i) per triplet.

    NOTE(review): relies on torch_sparse SparseTensor storage ordering —
    presumably `adj_t_col.t().storage.col()/.value()` align with the
    repeat_interleave expansion order; verify against torch_sparse docs.
    """
    row, col = edge_index

    # Edge ids 0..E-1 stored as the sparse values, so slicing the adjacency
    # carries edge identities along with connectivity.
    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=row, col=col, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    # Column-select by each edge's source: one column per edge (j of j -> i),
    # holding all edges k -> j that can precede it.
    adj_t_col = adj_t[:, row]
    # Number of candidate first hops per edge (before filtering).
    num_triplets = adj_t_col.set_value(None).sum(dim=0).to(torch.long)

    # Expand each edge (j -> i) once per candidate first hop.
    idx_j = row.repeat_interleave(num_triplets)
    idx_i = col.repeat_interleave(num_triplets)
    edx_2nd = value.repeat_interleave(num_triplets)
    # Matching first-hop sources and their edge ids from the sparse storage.
    idx_k = adj_t_col.t().storage.col()
    edx_1st = adj_t_col.t().storage.value()
    mask1 = (idx_i == idx_k) & (idx_j != idx_i) # Remove go back triplets.
    mask2 = (idx_i == idx_j) & (idx_j != idx_k) # Remove repeat self loop triplets
    mask3 = (idx_j == idx_k) & (idx_i != idx_k) # Remove self-loop neighbors
    mask = ~(mask1 | mask2 | mask3)
    idx_i, idx_j, idx_k, edx_1st, edx_2nd = idx_i[mask], idx_j[mask], idx_k[mask], edx_1st[mask], edx_2nd[mask]

    # Surviving triplet count per edge: cumulative candidates minus the
    # cumulative number filtered up to each edge's last candidate slot.
    num_triplets_real = torch.cumsum(num_triplets, dim=0) - torch.cumsum(~mask, dim=0)[torch.cumsum(num_triplets, dim=0)-1]

    return torch.stack([idx_i, idx_j, idx_k]), num_triplets_real.to(torch.long), edx_1st, edx_2nd