chris1nexus committed on
Commit d60982d
1 Parent(s): 9f4ac91

First commit

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. README.md +4 -4
  2. app.py +145 -0
  3. feature_extractor/.ipynb_checkpoints/weights_check-checkpoint.ipynb +0 -0
  4. feature_extractor/__init__.py +0 -0
  5. feature_extractor/__pycache__/__init__.cpython-38.pyc +0 -0
  6. feature_extractor/__pycache__/build_graph_utils.cpython-38.pyc +0 -0
  7. feature_extractor/__pycache__/build_graphs.cpython-38.pyc +0 -0
  8. feature_extractor/__pycache__/cl.cpython-38.pyc +0 -0
  9. feature_extractor/__pycache__/simclr.cpython-36.pyc +0 -0
  10. feature_extractor/__pycache__/simclr.cpython-38.pyc +0 -0
  11. feature_extractor/build_graph_utils.py +85 -0
  12. feature_extractor/build_graphs.py +114 -0
  13. feature_extractor/cl.py +83 -0
  14. feature_extractor/config.yaml +23 -0
  15. feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-36.pyc +0 -0
  16. feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-38.pyc +0 -0
  17. feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-36.pyc +0 -0
  18. feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-38.pyc +0 -0
  19. feature_extractor/data_aug/dataset_wrapper.py +93 -0
  20. feature_extractor/data_aug/gaussian_blur.py +26 -0
  21. feature_extractor/load_patches.py +37 -0
  22. feature_extractor/loss/__pycache__/nt_xent.cpython-36.pyc +0 -0
  23. feature_extractor/loss/__pycache__/nt_xent.cpython-38.pyc +0 -0
  24. feature_extractor/loss/nt_xent.py +65 -0
  25. feature_extractor/models/__init__.py +0 -0
  26. feature_extractor/models/__pycache__/__init__.cpython-38.pyc +0 -0
  27. feature_extractor/models/__pycache__/resnet_simclr.cpython-36.pyc +0 -0
  28. feature_extractor/models/__pycache__/resnet_simclr.cpython-38.pyc +0 -0
  29. feature_extractor/models/baseline_encoder.py +43 -0
  30. feature_extractor/models/resnet_simclr.py +37 -0
  31. feature_extractor/run.py +21 -0
  32. feature_extractor/simclr.py +165 -0
  33. feature_extractor/viewer.py +227 -0
  34. helper.py +104 -0
  35. main.py +169 -0
  36. metadata/label_map.pkl +3 -0
  37. models/.gitkeep +1 -0
  38. models/GraphTransformer.py +123 -0
  39. models/ViT.py +415 -0
  40. models/__init__.py +0 -0
  41. models/__pycache__/GraphTransformer.cpython-38.pyc +0 -0
  42. models/__pycache__/ViT.cpython-38.pyc +0 -0
  43. models/__pycache__/__init__.cpython-38.pyc +0 -0
  44. models/__pycache__/gcn.cpython-38.pyc +0 -0
  45. models/__pycache__/layers.cpython-38.pyc +0 -0
  46. models/__pycache__/weight_init.cpython-38.pyc +0 -0
  47. models/gcn.py +420 -0
  48. models/layers.py +280 -0
  49. models/weight_init.py +78 -0
  50. option.py +41 -0
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: AioMedica
- emoji: 👁
- colorFrom: yellow
- colorTo: indigo
+ title: MedFormer
+ emoji: 🏃
+ colorFrom: purple
+ colorTo: yellow
  sdk: streamlit
  sdk_version: 1.10.0
  app_file: app.py
app.py ADDED
@@ -0,0 +1,145 @@
1
+ import streamlit as st
2
+ import openslide
3
+ import os
4
+ from streamlit_option_menu import option_menu
5
+ import torch
6
+
7
+
8
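+ # Install the torch-geometric stack (torch-scatter, torch-sparse, torch-geometric) at startup,
+ # picking the CUDA 11.3 or CPU wheel index built for torch 1.7.1 depending on the available hardware.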
+ if torch.cuda.is_available():
9
+ os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
10
+ os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
11
+ os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html")
12
+ else:
13
+ os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
14
+ os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
15
+ os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
16
+
17
+ from predict import Predictor
18
+
19
+
20
+
21
+ # environment variables for the inference api
22
+ os.environ['DATA_DIR'] = 'queries'
23
+ os.environ['PATCHES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'patches')
24
+ os.environ['SLIDES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'slides')
25
+ os.environ['GRAPHCAM_DIR'] = os.path.join(os.environ['DATA_DIR'], 'graphcam_plots')
26
+ os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True)
27
+
28
+
29
+ # manually put the metadata in the metadata folder
30
+ os.environ['CLASS_METADATA'] ='metadata/label_map.pkl'
31
+
32
+ # manually put the desired weights in the weights folder
33
+ os.environ['WEIGHTS_PATH'] = 'weights'
34
+ os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'feature_extractor', 'model.pth')
35
+ os.environ['GT_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'graph_transformer', 'GraphCAM.pth')
36
+
37
+
38
+ st.set_page_config(page_title="",layout='wide')
39
+ predictor = Predictor()
40
+
41
+
42
+
43
+
44
+
45
+ ABOUT_TEXT = "🤗 LastMinute Medical - Web diagnosis tool."
46
+ CONTACT_TEXT = """
47
+ _Built by Christian Cancedda and LabLab lads with love_ ❤️
48
+ [![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus)
49
+ [![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda)
50
+ """
51
+ VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window"
52
+ DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease"
53
+
54
+
55
+
56
+ with st.sidebar:
57
+ choice = option_menu("LastMinute - Diagnosis",
58
+ ["About", "Visualize WSI slide", "Cancer Detection", "Contact"],
59
+ icons=['house', 'upload', 'activity', 'person lines fill'],
60
+ menu_icon="app-indicator", default_index=0,
61
+ styles={
62
+ # "container": {"padding": "5!important", "background-color": "#fafafa", },
63
+ "container": {"border-radius": ".0rem"},
64
+ # "icon": {"color": "orange", "font-size": "25px"},
65
+ # "nav-link": {"font-size": "16px", "text-align": "left", "margin": "0px",
66
+ # "--hover-color": "#eee"},
67
+ # "nav-link-selected": {"background-color": "#02ab21"},
68
+ }
69
+ )
70
+ st.sidebar.markdown(
71
+ """
72
+ <style>
73
+ .aligncenter {
74
+ text-align: center;
75
+ }
76
+ </style>
77
+ <p class="aligncenter">
78
+ <a href="https://twitter.com/chris_cancedda" target="_blank">
79
+ <img src="https://img.shields.io/twitter/follow/chris_cancedda?style=social"/>
80
+ </a>
81
+ </p>
82
+ """,
83
+ unsafe_allow_html=True,
84
+ )
85
+
86
+
87
+
88
+ if choice == "About":
89
+ st.title(choice)
90
+
91
+
92
+
93
+ if choice == "Visualize WSI slide":
94
+ st.title(choice)
95
+ st.markdown(VISUALIZE_TEXT)
96
+
97
+ uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
98
+ if uploaded_file is not None:
99
+ ori = openslide.OpenSlide(uploaded_file.name)
100
+ width, height = ori.dimensions
101
+
102
+ REDUCTION_FACTOR = 20
103
+ w, h = int(width/512), int(height/512)
104
+ w_r, h_r = int(width/20), int(height/20)
105
+ resized_img = ori.get_thumbnail((w_r,h_r))
106
+ resized_img = resized_img.resize((w_r,h_r))
107
+ ratio_w, ratio_h = width/resized_img.width, height/resized_img.height
108
+ #print('ratios ', ratio_w, ratio_h)
109
+ w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR)
110
+ st.image(resized_img, use_column_width='never')
111
+
112
+ if choice == "Cancer Detection":
113
+ state = dict()
114
+
115
+ st.title(choice)
116
+ st.markdown(DETECT_TEXT)
117
+ uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
118
+ if uploaded_file is not None:
119
+ # To read file as bytes:
120
+ #print(uploaded_file)
121
+ with open(os.path.join(uploaded_file.name),"wb") as f:
122
+ f.write(uploaded_file.getbuffer())
123
+ with st.spinner(text="Computation is running"):
124
+ predicted_class, viz_dict = predictor.predict(uploaded_file.name)
125
+ st.info('Computation completed.')
126
+ st.header(f'Predicted to be: {predicted_class}')
127
+ st.text('Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected')
128
+ state['cur'] = predicted_class
129
+ mapper = {'ORI': predicted_class, predicted_class:'ORI'}
130
+ readable_mapper = {'ORI': 'Original', predicted_class :'Disease heatmap' }
131
+ #def fn():
132
+ # st.image(viz_dict[mapper[state['cur']]], use_column_width='never', channels='BGR')
133
+ # state['cur'] = mapper[state['cur']]
134
+ # return
135
+
136
+ #st.button(f'See {readable_mapper[mapper[state["cur"]] ]}', on_click=fn )
137
+ #st.image(viz_dict[state['cur']], use_column_width='never', channels='BGR')
138
+ st.image([viz_dict[state['cur']],viz_dict['ORI']], caption=['Original', f'{predicted_class} heatmap'] ,channels='BGR'
139
+ # use_column_width='never',
140
+ )
141
+
142
+
143
+ if choice == "Contact":
144
+ st.title(choice)
145
+ st.markdown(CONTACT_TEXT)
feature_extractor/.ipynb_checkpoints/weights_check-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
feature_extractor/__init__.py ADDED
File without changes
feature_extractor/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (179 Bytes). View file
 
feature_extractor/__pycache__/build_graph_utils.cpython-38.pyc ADDED
Binary file (3.48 kB). View file
 
feature_extractor/__pycache__/build_graphs.cpython-38.pyc ADDED
Binary file (6.45 kB). View file
 
feature_extractor/__pycache__/cl.cpython-38.pyc ADDED
Binary file (3.05 kB). View file
 
feature_extractor/__pycache__/simclr.cpython-36.pyc ADDED
Binary file (4.38 kB). View file
 
feature_extractor/__pycache__/simclr.cpython-38.pyc ADDED
Binary file (4.5 kB). View file
 
feature_extractor/build_graph_utils.py ADDED
@@ -0,0 +1,85 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.data import DataLoader
5
+ import torchvision.models as models
6
+ import torchvision.transforms.functional as VF
7
+ from torchvision import transforms
8
+
9
+ import sys, argparse, os, glob
10
+ import pandas as pd
11
+ import numpy as np
12
+ from PIL import Image
13
+ from collections import OrderedDict
14
+
15
+ class ToPIL(object):
16
+ def __call__(self, sample):
17
+ img = sample
18
+ img = transforms.functional.to_pil_image(img)
19
+ return img
20
+
21
+ class BagDataset():
22
+ def __init__(self, csv_file, transform=None):
23
+ self.files_list = csv_file
24
+ self.transform = transform
25
+ def __len__(self):
26
+ return len(self.files_list)
27
+ def __getitem__(self, idx):
28
+ temp_path = self.files_list[idx]
29
+ img = os.path.join(temp_path)
30
+ img = Image.open(img)
31
+ img = img.resize((224, 224))
32
+ sample = {'input': img}
33
+
34
+ if self.transform:
35
+ sample = self.transform(sample)
36
+ return sample
37
+
38
+ class ToTensor(object):
39
+ def __call__(self, sample):
40
+ img = sample['input']
41
+ img = VF.to_tensor(img)
42
+ return {'input': img}
43
+
44
+ class Compose(object):
45
+ def __init__(self, transforms):
46
+ self.transforms = transforms
47
+
48
+ def __call__(self, img):
49
+ for t in self.transforms:
50
+ img = t(img)
51
+ return img
52
+
53
+ def save_coords(txt_file, csv_file_path):
54
+ for path in csv_file_path:
55
+ x, y = path.split('/')[-1].split('.')[0].split('_')
56
+ txt_file.writelines(str(x) + '\t' + str(y) + '\n')
57
+ txt_file.close()
58
+
59
+ def adj_matrix(csv_file_path, output, device='cpu'):
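+ # Patch filenames encode tile grid coordinates as '<x>_<y>.<ext>'; two patches are marked
+ # adjacent when their coordinates differ by at most one step in each axis (8-connected neighbourhood).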
60
+ total = len(csv_file_path)
61
+ adj_s = np.zeros((total, total))
62
+
63
+ for i in range(total-1):
64
+ path_i = csv_file_path[i]
65
+ x_i, y_i = path_i.split('/')[-1].split('.')[0].split('_')
66
+ for j in range(i+1, total):
67
+ # sptial
68
+ path_j = csv_file_path[j]
69
+ x_j, y_j = path_j.split('/')[-1].split('.')[0].split('_')
70
+ if abs(int(x_i)-int(x_j)) <=1 and abs(int(y_i)-int(y_j)) <= 1:
71
+ adj_s[i][j] = 1
72
+ adj_s[j][i] = 1
73
+
74
+ adj_s = torch.from_numpy(adj_s)
75
+ adj_s = adj_s.to(device)
76
+
77
+ return adj_s
78
+
79
+ def bag_dataset(args, csv_file_path):
80
+ transformed_dataset = BagDataset(csv_file=csv_file_path,
81
+ transform=Compose([
82
+ ToTensor()
83
+ ]))
84
+ dataloader = DataLoader(transformed_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False)
85
+ return dataloader, len(transformed_dataset)
feature_extractor/build_graphs.py ADDED
@@ -0,0 +1,114 @@
1
+
2
+ from cl import IClassifier
3
+ from build_graph_utils import *
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.data import DataLoader
7
+ import torchvision.models as models
8
+ import torchvision.transforms.functional as VF
9
+ from torchvision import transforms
10
+
11
+ import sys, argparse, os, glob
12
+ import pandas as pd
13
+ import numpy as np
14
+ from PIL import Image
15
+ from collections import OrderedDict
16
+
17
+
18
+
19
+ def compute_feats(args, bags_list, i_classifier, device, save_path=None, whole_slide_path=None):
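+ # For each bag (one folder of patches per slide), run the embedder over every patch without gradients,
+ # then save the per-patch features (features.pt), the patch coordinates (c_idx.txt) and the
+ # spatial adjacency matrix (adj_s.pt) under <save_path>/simclr_files/<slide_name>.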
20
+ num_bags = len(bags_list)
21
+ Tensor = torch.FloatTensor
22
+ for i in range(0, num_bags):
23
+ feats_list = []
24
+ if args.magnification == '20x':
25
+ glob_path = os.path.join(bags_list[i], '*.jpeg')
26
+ csv_file_path = glob.glob(glob_path)
27
+ # line below was in the original version, commented out due to an error with the current version
28
+ #file_name = bags_list[i].split('/')[-3].split('_')[0]
29
+
30
+ file_name = glob_path.split('/')[-3].split('_')[0]
31
+
32
+ if args.magnification == '5x' or args.magnification == '10x':
33
+ csv_file_path = glob.glob(os.path.join(bags_list[i], '*.jpg'))
34
+
35
+ dataloader, bag_size = bag_dataset(args, csv_file_path)
36
+ print('{} files to be processed: {}'.format(len(csv_file_path), file_name))
37
+
38
+ if os.path.isdir(os.path.join(save_path, 'simclr_files', file_name)) or len(csv_file_path) < 1:
39
+ print('already exists')
40
+ continue
41
+ with torch.no_grad():
42
+ for iteration, batch in enumerate(dataloader):
43
+ patches = batch['input'].float().to(device)
44
+ feats, classes = i_classifier(patches)
45
+ #feats = feats.cpu().numpy()
46
+ feats_list.extend(feats)
47
+
48
+ os.makedirs(os.path.join(save_path, 'simclr_files', file_name), exist_ok=True)
49
+
50
+ txt_file = open(os.path.join(save_path, 'simclr_files', file_name, 'c_idx.txt'), "w+")
51
+ save_coords(txt_file, csv_file_path)
52
+ # save node features
53
+ output = torch.stack(feats_list, dim=0).to(device)
54
+ torch.save(output, os.path.join(save_path, 'simclr_files', file_name, 'features.pt'))
55
+ # save adjacent matrix
56
+ adj_s = adj_matrix(csv_file_path, output, device=device)
57
+ torch.save(adj_s, os.path.join(save_path, 'simclr_files', file_name, 'adj_s.pt'))
58
+
59
+ print('\r Computed: {}/{}'.format(i+1, num_bags))
60
+
61
+
62
+ def main():
63
+ parser = argparse.ArgumentParser(description='Compute TCGA features from SimCLR embedder')
64
+ parser.add_argument('--num_classes', default=2, type=int, help='Number of output classes')
65
+ parser.add_argument('--num_feats', default=512, type=int, help='Feature size')
66
+ parser.add_argument('--batch_size', default=128, type=int, help='Batch size of dataloader')
67
+ parser.add_argument('--num_workers', default=0, type=int, help='Number of threads for dataloader')
68
+ parser.add_argument('--dataset', default=None, type=str, help='path to patches')
69
+ parser.add_argument('--backbone', default='resnet18', type=str, help='Embedder backbone')
70
+ parser.add_argument('--magnification', default='20x', type=str, help='Magnification to compute features')
71
+ parser.add_argument('--weights', default=None, type=str, help='path to the pretrained weights')
72
+ parser.add_argument('--output', default=None, type=str, help='path to the output graph folder')
73
+ args = parser.parse_args()
74
+
75
+ if args.backbone == 'resnet18':
76
+ resnet = models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d)
77
+ num_feats = 512
78
+ if args.backbone == 'resnet34':
79
+ resnet = models.resnet34(pretrained=False, norm_layer=nn.InstanceNorm2d)
80
+ num_feats = 512
81
+ if args.backbone == 'resnet50':
82
+ resnet = models.resnet50(pretrained=False, norm_layer=nn.InstanceNorm2d)
83
+ num_feats = 2048
84
+ if args.backbone == 'resnet101':
85
+ resnet = models.resnet101(pretrained=False, norm_layer=nn.InstanceNorm2d)
86
+ num_feats = 2048
87
+ for param in resnet.parameters():
88
+ param.requires_grad = False
89
+ resnet.fc = nn.Identity()
90
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
91
+ print("Running on:", device)
92
+ i_classifier = IClassifier(resnet, num_feats, output_class=args.num_classes).to(device)
93
+
94
+ # load feature extractor
95
+ if args.weights is None:
96
+ print('No feature extractor')
97
+ return
98
+ state_dict_weights = torch.load(args.weights)
99
+ state_dict_init = i_classifier.state_dict()
100
+ new_state_dict = OrderedDict()
101
+ for (k, v), (k_0, v_0) in zip(state_dict_weights.items(), state_dict_init.items()):
102
+ if 'features' not in k:
103
+ continue
104
+ name = k_0
105
+ new_state_dict[name] = v
106
+ i_classifier.load_state_dict(new_state_dict, strict=False)
107
+
108
+ os.makedirs(args.output, exist_ok=True)
109
+ bags_list = glob.glob(args.dataset)
110
+ print(bags_list)
111
+ compute_feats(args, bags_list, i_classifier, device, args.output)
112
+
113
+ if __name__ == '__main__':
114
+ main()
feature_extractor/cl.py ADDED
@@ -0,0 +1,83 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+
6
+ class FCLayer(nn.Module):
7
+ def __init__(self, in_size, out_size=1):
8
+ super(FCLayer, self).__init__()
9
+ self.fc = nn.Sequential(nn.Linear(in_size, out_size))
10
+ def forward(self, feats):
11
+ x = self.fc(feats)
12
+ return feats, x
13
+
14
+ class IClassifier(nn.Module):
15
+ def __init__(self, feature_extractor, feature_size, output_class):
16
+ super(IClassifier, self).__init__()
17
+
18
+ self.feature_extractor = feature_extractor
19
+ self.fc = nn.Linear(feature_size, output_class)
20
+
21
+
22
+ def forward(self, x):
23
+ device = x.device
24
+ feats = self.feature_extractor(x) # N x K
25
+ c = self.fc(feats.view(feats.shape[0], -1)) # N x C
26
+ return feats.view(feats.shape[0], -1), c
27
+
28
+ class BClassifier(nn.Module):
29
+ def __init__(self, input_size, output_class, dropout_v=0.0): # K, L, N
30
+ super(BClassifier, self).__init__()
31
+ self.q = nn.Linear(input_size, 128)
32
+ self.v = nn.Sequential(
33
+ nn.Dropout(dropout_v),
34
+ nn.Linear(input_size, input_size)
35
+ )
36
+
37
+ ### 1D convolutional layer that can handle multiple class (including binary)
38
+ self.fcc = nn.Conv1d(output_class, output_class, kernel_size=input_size)
39
+
40
+ def forward(self, feats, c): # N x K, N x C
41
+ device = feats.device
42
+ V = self.v(feats) # N x V, unsorted
43
+ Q = self.q(feats).view(feats.shape[0], -1) # N x Q, unsorted
44
+
45
+ # handle multiple classes without for loop
46
+ _, m_indices = torch.sort(c, 0, descending=True) # sort class scores along the instance dimension, m_indices in shape N x C
47
+ m_feats = torch.index_select(feats, dim=0, index=m_indices[0, :]) # select critical instances, m_feats in shape C x K
48
+ q_max = self.q(m_feats) # compute queries of critical instances, q_max in shape C x Q
49
+ A = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores
50
+ A = F.softmax( A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C,
51
+ B = torch.mm(A.transpose(0, 1), V) # compute bag representation, B in shape C x V
52
+
53
+
54
+ # for i in range(c.shape[1]):
55
+ # _, indices = torch.sort(c[:, i], 0, True)
56
+ # feats = torch.index_select(feats, 0, indices) # N x K, sorted
57
+ # q_max = self.q(feats[0].view(1, -1)) # 1 x 1 x Q
58
+ # temp = torch.mm(Q, q_max.view(-1, 1)) / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device))
59
+ # if i == 0:
60
+ # A = F.softmax(temp, 0) # N x 1
61
+ # B = torch.sum(torch.mul(A, V), 0).view(1, -1) # 1 x V
62
+ # else:
63
+ # temp = F.softmax(temp, 0) # N x 1
64
+ # A = torch.cat((A, temp), 1) # N x C
65
+ # B = torch.cat((B, torch.sum(torch.mul(temp, V), 0).view(1, -1)), 0) # C x V -> 1 x C x V
66
+
67
+ B = B.view(1, B.shape[0], B.shape[1]) # 1 x C x V
68
+ C = self.fcc(B) # 1 x C x 1
69
+ C = C.view(1, -1)
70
+ return C, A, B
71
+
72
+ class MILNet(nn.Module):
73
+ def __init__(self, i_classifier, b_classifier):
74
+ super(MILNet, self).__init__()
75
+ self.i_classifier = i_classifier
76
+ self.b_classifier = b_classifier
77
+
78
+ def forward(self, x):
79
+ feats, classes = self.i_classifier(x)
80
+ prediction_bag, A, B = self.b_classifier(feats, classes)
81
+
82
+ return classes, prediction_bag, A, B
83
+
feature_extractor/config.yaml ADDED
@@ -0,0 +1,23 @@
1
+ batch_size: 256
2
+ epochs: 20
3
+ eval_every_n_epochs: 1
4
+ fine_tune_from: ''
5
+ log_every_n_steps: 25
6
+ weight_decay: 10e-6
7
+ fp16_precision: False
8
+ n_gpu: 2
9
+ gpu_ids: (0,1)
10
+
11
+ model:
12
+ out_dim: 512
13
+ base_model: "resnet18"
14
+
15
+ dataset:
16
+ s: 1
17
+ input_shape: (224,224,3)
18
+ num_workers: 10
19
+ valid_size: 0.1
20
+
21
+ loss:
22
+ temperature: 0.5
23
+ use_cosine_similarity: True
feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-36.pyc ADDED
Binary file (3.83 kB). View file
 
feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-38.pyc ADDED
Binary file (4 kB). View file
 
feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-36.pyc ADDED
Binary file (896 Bytes). View file
 
feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-38.pyc ADDED
Binary file (932 Bytes). View file
 
feature_extractor/data_aug/dataset_wrapper.py ADDED
@@ -0,0 +1,93 @@
1
+ import numpy as np
2
+ from torch.utils.data import DataLoader
3
+ from torch.utils.data.sampler import SubsetRandomSampler
4
+ import torchvision.transforms as transforms
5
+ from data_aug.gaussian_blur import GaussianBlur
6
+ from torchvision import datasets
7
+ import pandas as pd
8
+ from PIL import Image
9
+ from skimage import io, img_as_ubyte
10
+
11
+ np.random.seed(0)
12
+
13
+ class Dataset():
14
+ def __init__(self, csv_file, transform=None):
15
+ lines = []
16
+ with open(csv_file) as f:
17
+ for line in f:
18
+ line = line.rstrip().strip()
19
+ lines.append(line)
20
+ self.files_list = lines#pd.read_csv(csv_file)
21
+ self.transform = transform
22
+ def __len__(self):
23
+ return len(self.files_list)
24
+ def __getitem__(self, idx):
25
+ temp_path = self.files_list[idx]# self.files_list.iloc[idx, 0]
26
+ img = Image.open(temp_path)
27
+ img = transforms.functional.to_tensor(img)
28
+ if self.transform:
29
+ sample = self.transform(img)
30
+ return sample
31
+
32
+ class ToPIL(object):
33
+ def __call__(self, sample):
34
+ img = sample
35
+ img = transforms.functional.to_pil_image(img)
36
+ return img
37
+
38
+ class DataSetWrapper(object):
39
+
40
+ def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
41
+ self.batch_size = batch_size
42
+ self.num_workers = num_workers
43
+ self.valid_size = valid_size
44
+ self.s = s
45
+ self.input_shape = eval(input_shape)
46
+
47
+ def get_data_loaders(self):
48
+ data_augment = self._get_simclr_pipeline_transform()
49
+ train_dataset = Dataset(csv_file='all_patches.csv', transform=SimCLRDataTransform(data_augment))
50
+ train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
51
+ return train_loader, valid_loader
52
+
53
+ def _get_simclr_pipeline_transform(self):
54
+ # get a set of data augmentation transformations as described in the SimCLR paper.
55
+ color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
56
+ data_transforms = transforms.Compose([ToPIL(),
57
+ # transforms.RandomResizedCrop(size=self.input_shape[0]),
58
+ transforms.Resize((self.input_shape[0],self.input_shape[1])),
59
+ transforms.RandomHorizontalFlip(),
60
+ transforms.RandomApply([color_jitter], p=0.8),
61
+ transforms.RandomGrayscale(p=0.2),
62
+ GaussianBlur(kernel_size=int(0.06 * self.input_shape[0])),
63
+ transforms.ToTensor()])
64
+ return data_transforms
65
+
66
+ def get_train_validation_data_loaders(self, train_dataset):
67
+ # obtain training indices that will be used for validation
68
+ num_train = len(train_dataset)
69
+ indices = list(range(num_train))
70
+ np.random.shuffle(indices)
71
+
72
+ split = int(np.floor(self.valid_size * num_train))
73
+ train_idx, valid_idx = indices[split:], indices[:split]
74
+
75
+ # define samplers for obtaining training and validation batches
76
+ train_sampler = SubsetRandomSampler(train_idx)
77
+ valid_sampler = SubsetRandomSampler(valid_idx)
78
+
79
+ train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
80
+ num_workers=self.num_workers, drop_last=True, shuffle=False)
81
+ valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
82
+ num_workers=self.num_workers, drop_last=True)
83
+ return train_loader, valid_loader
84
+
85
+
86
+ class SimCLRDataTransform(object):
87
+ def __init__(self, transform):
88
+ self.transform = transform
89
+
90
+ def __call__(self, sample):
91
+ xi = self.transform(sample)
92
+ xj = self.transform(sample)
93
+ return xi, xj
feature_extractor/data_aug/gaussian_blur.py ADDED
@@ -0,0 +1,26 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ np.random.seed(0)
5
+
6
+
7
+ class GaussianBlur(object):
8
+ # Implements Gaussian blur as described in the SimCLR paper
9
+ def __init__(self, kernel_size, min=0.1, max=2.0):
10
+ self.min = min
11
+ self.max = max
12
+ # kernel size is set to be 10% of the image height/width
13
+ self.kernel_size = kernel_size
14
+
15
+ def __call__(self, sample):
16
+ sample = np.array(sample)
17
+
18
+ # blur the image with a 50% chance
19
+ prob = np.random.random_sample()
20
+
21
+ if prob < 0.5:
22
+ # print(self.kernel_size)
23
+ sigma = (self.max - self.min) * np.random.random_sample() + self.min
24
+ sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
25
+
26
+ return sample
feature_extractor/load_patches.py ADDED
@@ -0,0 +1,37 @@
1
+
2
+ import os, glob
3
+ import argparse
4
+
5
+ def main():
6
+ parser = argparse.ArgumentParser()
7
+ parser.add_argument('--data_path', type=str)
8
+ args = parser.parse_args()
9
+
10
+ wsi_slides_paths = []
11
+
12
+
13
+ def r(dirpath):
14
+ for file in os.listdir(dirpath):
15
+ path = os.path.join(dirpath, file)
16
+ if os.path.isfile(path) and file.endswith(".svs"):
17
+ wsi_slides_paths.append(path)
18
+ elif os.path.isdir(path):
19
+ r(path)
20
+ def r(dirpath):
21
+ for path in glob.glob(os.path.join(dirpath, '*','*.svs') ):#os.listdir(dirpath):
22
+ if os.path.isfile(path):
23
+ wsi_slides_paths.append(path)
24
+ def r(dirpath):
25
+ for path in glob.glob(os.path.join(dirpath, '*', '*', '*.jpeg') ):#os.listdir(dirpath):
26
+ if os.path.isfile(path):
27
+ wsi_slides_paths.append(path)
28
+ r(args.data_path)
29
+ with open('all_patches.csv', 'w') as f:
30
+ for filepath in wsi_slides_paths:
31
+ f.write(f'{filepath}\n')
32
+
33
+
34
+
35
+
36
+ if __name__ == "__main__":
37
+ main()
feature_extractor/loss/__pycache__/nt_xent.cpython-36.pyc ADDED
Binary file (2.45 kB). View file
 
feature_extractor/loss/__pycache__/nt_xent.cpython-38.pyc ADDED
Binary file (2.49 kB). View file
 
feature_extractor/loss/nt_xent.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class NTXentLoss(torch.nn.Module):
6
+
7
+ def __init__(self, device, batch_size, temperature, use_cosine_similarity):
8
+ super(NTXentLoss, self).__init__()
9
+ self.batch_size = batch_size
10
+ self.temperature = temperature
11
+ self.device = device
12
+ self.softmax = torch.nn.Softmax(dim=-1)
13
+ self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
14
+ self.similarity_function = self._get_similarity_function(use_cosine_similarity)
15
+ self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
16
+
17
+ def _get_similarity_function(self, use_cosine_similarity):
18
+ if use_cosine_similarity:
19
+ self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
20
+ return self._cosine_simililarity
21
+ else:
22
+ return self._dot_simililarity
23
+
24
+ def _get_correlated_mask(self):
25
+ diag = np.eye(2 * self.batch_size)
26
+ l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
27
+ l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
28
+ mask = torch.from_numpy((diag + l1 + l2))
29
+ mask = (1 - mask).type(torch.bool)
30
+ return mask.to(self.device)
31
+
32
+ @staticmethod
33
+ def _dot_simililarity(x, y):
34
+ v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
35
+ # x shape: (N, 1, C)
36
+ # y shape: (1, C, 2N)
37
+ # v shape: (N, 2N)
38
+ return v
39
+
40
+ def _cosine_simililarity(self, x, y):
41
+ # x shape: (N, 1, C)
42
+ # y shape: (1, 2N, C)
43
+ # v shape: (N, 2N)
44
+ v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
45
+ return v
46
+
47
+ def forward(self, zis, zjs):
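+ # NT-Xent: concatenate the two augmented batches, compute all pairwise similarities, then treat
+ # each sample's matching view as the positive and every other cross-sample pair as a negative.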
48
+ representations = torch.cat([zjs, zis], dim=0)
49
+
50
+ similarity_matrix = self.similarity_function(representations, representations)
51
+
52
+ # filter out the scores from the positive samples
53
+ l_pos = torch.diag(similarity_matrix, self.batch_size)
54
+ r_pos = torch.diag(similarity_matrix, -self.batch_size)
55
+ positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
56
+
57
+ negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
58
+
59
+ logits = torch.cat((positives, negatives), dim=1)
60
+ logits /= self.temperature
61
+
62
+ labels = torch.zeros(2 * self.batch_size).to(self.device).long()
63
+ loss = self.criterion(logits, labels)
64
+
65
+ return loss / (2 * self.batch_size)
feature_extractor/models/__init__.py ADDED
File without changes
feature_extractor/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (186 Bytes). View file
 
feature_extractor/models/__pycache__/resnet_simclr.cpython-36.pyc ADDED
Binary file (1.51 kB). View file
 
feature_extractor/models/__pycache__/resnet_simclr.cpython-38.pyc ADDED
Binary file (1.55 kB). View file
 
feature_extractor/models/baseline_encoder.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+
6
+
7
+ class Encoder(nn.Module):
8
+ def __init__(self, out_dim=64):
9
+ super(Encoder, self).__init__()
10
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
11
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
12
+ self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
13
+ self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
14
+ self.pool = nn.MaxPool2d(2, 2)
15
+
16
+ # projection MLP
17
+ self.l1 = nn.Linear(64, 64)
18
+ self.l2 = nn.Linear(64, out_dim)
19
+
20
+ def forward(self, x):
21
+ x = self.conv1(x)
22
+ x = F.relu(x)
23
+ x = self.pool(x)
24
+
25
+ x = self.conv2(x)
26
+ x = F.relu(x)
27
+ x = self.pool(x)
28
+
29
+ x = self.conv3(x)
30
+ x = F.relu(x)
31
+ x = self.pool(x)
32
+
33
+ x = self.conv4(x)
34
+ x = F.relu(x)
35
+ x = self.pool(x)
36
+
37
+ h = torch.mean(x, dim=[2, 3])
38
+
39
+ x = self.l1(h)
40
+ x = F.relu(x)
41
+ x = self.l2(x)
42
+
43
+ return h, x
feature_extractor/models/resnet_simclr.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torchvision.models as models
4
+
5
+
6
+ class ResNetSimCLR(nn.Module):
7
+
8
+ def __init__(self, base_model, out_dim):
9
+ super(ResNetSimCLR, self).__init__()
10
+ self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d),
11
+ "resnet50": models.resnet50(pretrained=False)}
12
+
13
+ resnet = self._get_basemodel(base_model)
14
+ num_ftrs = resnet.fc.in_features
15
+
16
+ self.features = nn.Sequential(*list(resnet.children())[:-1])
17
+
18
+ # projection MLP
19
+ self.l1 = nn.Linear(num_ftrs, num_ftrs)
20
+ self.l2 = nn.Linear(num_ftrs, out_dim)
21
+
22
+ def _get_basemodel(self, model_name):
23
+ try:
24
+ model = self.resnet_dict[model_name]
25
+ print("Feature extractor:", model_name)
26
+ return model
27
+ except:
28
+ raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
29
+
30
+ def forward(self, x):
31
+ h = self.features(x)
32
+ h = h.squeeze()
33
+
34
+ x = self.l1(h)
35
+ x = F.relu(x)
36
+ x = self.l2(x)
37
+ return h, x
feature_extractor/run.py ADDED
@@ -0,0 +1,21 @@
1
+ from simclr import SimCLR
2
+ import yaml
3
+ from data_aug.dataset_wrapper import DataSetWrapper
4
+ import os, glob
5
+ import pandas as pd
6
+ import argparse
7
+
8
+ def main():
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--magnification', type=str, default='20x')
11
+ parser.add_argument('--dest_weights', type=str)
12
+ args = parser.parse_args()
13
+ config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
14
+ dataset = DataSetWrapper(config['batch_size'], **config['dataset'])
15
+
16
+ simclr = SimCLR(dataset, config, args)
17
+ simclr.train()
18
+
19
+
20
+ if __name__ == "__main__":
21
+ main()
feature_extractor/simclr.py ADDED
@@ -0,0 +1,165 @@
1
+ import torch
2
+ from models.resnet_simclr import ResNetSimCLR
3
+ from torch.utils.tensorboard import SummaryWriter
4
+ import torch.nn.functional as F
5
+ from loss.nt_xent import NTXentLoss
6
+ import os
7
+ import shutil
8
+ import sys
9
+
10
+ apex_support = False
11
+ try:
12
+ sys.path.append('./apex')
13
+ from apex import amp
14
+
15
+ apex_support = True
16
+ except:
17
+ print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
18
+ apex_support = False
19
+
20
+ import numpy as np
21
+
22
+ torch.manual_seed(0)
23
+
24
+
25
+ def _save_config_file(model_checkpoints_folder):
26
+ if not os.path.exists(model_checkpoints_folder):
27
+ os.makedirs(model_checkpoints_folder)
28
+ shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))
29
+
30
+
31
+ class SimCLR(object):
32
+
33
+ def __init__(self, dataset, config, args=None):
34
+ self.config = config
35
+ self.device = self._get_device()
36
+ self.writer = SummaryWriter()
37
+ self.dataset = dataset
38
+ self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])
39
+ self.args = args
40
+ def _get_device(self):
41
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
42
+ print("Running on:", device)
43
+ return device
44
+
45
+ def _step(self, model, xis, xjs, n_iter):
46
+
47
+ # get the representations and the projections
48
+ ris, zis = model(xis) # [N,C]
49
+
50
+ # get the representations and the projections
51
+ rjs, zjs = model(xjs) # [N,C]
52
+
53
+ # normalize projection feature vectors
54
+ zis = F.normalize(zis, dim=1)
55
+ zjs = F.normalize(zjs, dim=1)
56
+
57
+ loss = self.nt_xent_criterion(zis, zjs)
58
+ return loss
59
+
60
+ def train(self):
61
+
62
+ train_loader, valid_loader = self.dataset.get_data_loaders()
63
+
64
+ model = ResNetSimCLR(**self.config["model"])# .to(self.device)
65
+ if self.config['n_gpu'] > 1:
66
+ model = torch.nn.DataParallel(model, device_ids=eval(self.config['gpu_ids']))
67
+ model = self._load_pre_trained_weights(model)
68
+ model = model.to(self.device)
69
+
70
+
71
+ optimizer = torch.optim.Adam(model.parameters(), 1e-5, weight_decay=eval(self.config['weight_decay']))
72
+
73
+ # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
74
+ # last_epoch=-1)
75
+
76
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config['epochs'], eta_min=0,
77
+ last_epoch=-1)
78
+
79
+
80
+ if apex_support and self.config['fp16_precision']:
81
+ model, optimizer = amp.initialize(model, optimizer,
82
+ opt_level='O2',
83
+ keep_batchnorm_fp32=True)
84
+
85
+ if self.args is None:
86
+ model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
87
+ else:
88
+ model_checkpoints_folder = self.args.dest_weights#os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH']
89
+ model_checkpoints_folder = os.path.dirname(model_checkpoints_folder)
90
+ # save config file
91
+ _save_config_file(model_checkpoints_folder)
92
+
93
+ n_iter = 0
94
+ valid_n_iter = 0
95
+ best_valid_loss = np.inf
96
+
97
+ for epoch_counter in range(self.config['epochs']):
98
+ for (xis, xjs) in train_loader:
99
+ optimizer.zero_grad()
100
+ xis = xis.to(self.device)
101
+ xjs = xjs.to(self.device)
102
+
103
+ loss = self._step(model, xis, xjs, n_iter)
104
+
105
+ if n_iter % self.config['log_every_n_steps'] == 0:
106
+ self.writer.add_scalar('train_loss', loss, global_step=n_iter)
107
+ print("[%d/%d] step: %d train_loss: %.3f" % (epoch_counter, self.config['epochs'], n_iter, loss))
108
+
109
+ if apex_support and self.config['fp16_precision']:
110
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
111
+ scaled_loss.backward()
112
+ else:
113
+ loss.backward()
114
+
115
+ optimizer.step()
116
+ n_iter += 1
117
+
118
+ # validate the model if requested
119
+ if epoch_counter % self.config['eval_every_n_epochs'] == 0:
120
+ valid_loss = self._validate(model, valid_loader)
121
+ print("[%d/%d] val_loss: %.3f" % (epoch_counter, self.config['epochs'], valid_loss))
122
+ if valid_loss < best_valid_loss:
123
+ # save the model weights
124
+ best_valid_loss = valid_loss
125
+ torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
126
+ print('saved')
127
+
128
+ self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
129
+ valid_n_iter += 1
130
+
131
+ # warmup for the first 10 epochs
132
+ if epoch_counter >= 10:
133
+ scheduler.step()
134
+ self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)
135
+
136
+ def _load_pre_trained_weights(self, model):
137
+ try:
138
+ checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
139
+ state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
140
+ model.load_state_dict(state_dict)
141
+ print("Loaded pre-trained model with success.")
142
+ except FileNotFoundError:
143
+ print("Pre-trained weights not found. Training from scratch.")
144
+
145
+ return model
146
+
147
+ def _validate(self, model, valid_loader):
148
+
149
+ # validation steps
150
+ with torch.no_grad():
151
+ model.eval()
152
+
153
+ valid_loss = 0.0
154
+ counter = 0
155
+
156
+ for (xis, xjs) in valid_loader:
157
+ xis = xis.to(self.device)
158
+ xjs = xjs.to(self.device)
159
+
160
+ loss = self._step(model, xis, xjs, counter)
161
+ valid_loss += loss.item()
162
+ counter += 1
163
+ valid_loss /= counter
164
+ model.train()
165
+ return valid_loss
feature_extractor/viewer.py ADDED
@@ -0,0 +1,227 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # deepzoom_server - Example web application for serving whole-slide images
4
+ #
5
+ # Copyright (c) 2010-2015 Carnegie Mellon University
6
+ #
7
+ # This library is free software; you can redistribute it and/or modify it
8
+ # under the terms of version 2.1 of the GNU Lesser General Public License
9
+ # as published by the Free Software Foundation.
10
+ #
11
+ # This library is distributed in the hope that it will be useful, but
12
+ # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
13
+ # or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
14
+ # License for more details.
15
+ #
16
+ # You should have received a copy of the GNU Lesser General Public License
17
+ # along with this library; if not, write to the Free Software Foundation,
18
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
19
+ #
20
+
21
+ from io import BytesIO
22
+ from optparse import OptionParser
23
+ import os
24
+ import re
25
+ from unicodedata import normalize
26
+
27
+ from flask import Flask, abort, make_response, render_template, url_for
28
+
29
+ if os.name == 'nt':
30
+ _dll_path = os.getenv('OPENSLIDE_PATH')
31
+ if _dll_path is not None:
32
+ if hasattr(os, 'add_dll_directory'):
33
+ # Python >= 3.8
34
+ with os.add_dll_directory(_dll_path):
35
+ import openslide
36
+ else:
37
+ # Python < 3.8
38
+ _orig_path = os.environ.get('PATH', '')
39
+ os.environ['PATH'] = _orig_path + ';' + _dll_path
40
+ import openslide
41
+
42
+ os.environ['PATH'] = _orig_path
43
+ else:
44
+ import openslide
45
+
46
+ from openslide import ImageSlide, open_slide
47
+ from openslide.deepzoom import DeepZoomGenerator
48
+
49
+ DEEPZOOM_SLIDE = None
50
+ DEEPZOOM_FORMAT = 'jpeg'
51
+ DEEPZOOM_TILE_SIZE = 254
52
+ DEEPZOOM_OVERLAP = 1
53
+ DEEPZOOM_LIMIT_BOUNDS = True
54
+ DEEPZOOM_TILE_QUALITY = 75
55
+ SLIDE_NAME = 'slide'
56
+
57
+ app = Flask(__name__)
58
+ app.config.from_object(__name__)
59
+ app.config.from_envvar('DEEPZOOM_TILER_SETTINGS', silent=True)
60
+
61
+
62
+ @app.before_first_request
63
+ def load_slide():
64
+ slidefile = app.config['DEEPZOOM_SLIDE']
65
+ if slidefile is None:
66
+ raise ValueError('No slide file specified')
67
+ config_map = {
68
+ 'DEEPZOOM_TILE_SIZE': 'tile_size',
69
+ 'DEEPZOOM_OVERLAP': 'overlap',
70
+ 'DEEPZOOM_LIMIT_BOUNDS': 'limit_bounds',
71
+ }
72
+ opts = {v: app.config[k] for k, v in config_map.items()}
73
+ slide = open_slide(slidefile)
74
+ app.slides = {SLIDE_NAME: DeepZoomGenerator(slide, **opts)}
75
+ app.associated_images = []
76
+ app.slide_properties = slide.properties
77
+ for name, image in slide.associated_images.items():
78
+ app.associated_images.append(name)
79
+ slug = slugify(name)
80
+ app.slides[slug] = DeepZoomGenerator(ImageSlide(image), **opts)
81
+ try:
82
+ mpp_x = slide.properties[openslide.PROPERTY_NAME_MPP_X]
83
+ mpp_y = slide.properties[openslide.PROPERTY_NAME_MPP_Y]
84
+ app.slide_mpp = (float(mpp_x) + float(mpp_y)) / 2
85
+ except (KeyError, ValueError):
86
+ app.slide_mpp = 0
87
+
88
+
89
+ @app.route('/')
90
+ def index():
91
+ slide_url = url_for('dzi', slug=SLIDE_NAME)
92
+ associated_urls = {
93
+ name: url_for('dzi', slug=slugify(name)) for name in app.associated_images
94
+ }
95
+ return render_template(
96
+ 'slide-multipane.html',
97
+ slide_url=slide_url,
98
+ associated=associated_urls,
99
+ properties=app.slide_properties,
100
+ slide_mpp=app.slide_mpp,
101
+ )
102
+
103
+
104
+ @app.route('/<slug>.dzi')
105
+ def dzi(slug):
106
+ format = app.config['DEEPZOOM_FORMAT']
107
+ try:
108
+ resp = make_response(app.slides[slug].get_dzi(format))
109
+ resp.mimetype = 'application/xml'
110
+ return resp
111
+ except KeyError:
112
+ # Unknown slug
113
+ abort(404)
114
+
115
+
116
+ @app.route('/<slug>_files/<int:level>/<int:col>_<int:row>.<format>')
117
+ def tile(slug, level, col, row, format):
118
+ format = format.lower()
119
+ if format != 'jpeg' and format != 'png':
120
+ # Not supported by Deep Zoom
121
+ abort(404)
122
+ try:
123
+ tile = app.slides[slug].get_tile(level, (col, row))
124
+ except KeyError:
125
+ # Unknown slug
126
+ abort(404)
127
+ except ValueError:
128
+ # Invalid level or coordinates
129
+ abort(404)
130
+ buf = BytesIO()
131
+ tile.save(buf, format, quality=app.config['DEEPZOOM_TILE_QUALITY'])
132
+ resp = make_response(buf.getvalue())
133
+ resp.mimetype = 'image/%s' % format
134
+ return resp
135
+
136
+
137
+ def slugify(text):
138
+ text = normalize('NFKD', text.lower()).encode('ascii', 'ignore').decode()
139
+ return re.sub('[^a-z0-9]+', '-', text)
140
+
141
+
142
+ if __name__ == '__main__':
143
+ parser = OptionParser(usage='Usage: %prog [options] [slide]')
144
+ parser.add_option(
145
+ '-B',
146
+ '--ignore-bounds',
147
+ dest='DEEPZOOM_LIMIT_BOUNDS',
148
+ default=True,
149
+ action='store_false',
150
+ help='display entire scan area',
151
+ )
152
+ parser.add_option(
153
+ '-c', '--config', metavar='FILE', dest='config', help='config file'
154
+ )
155
+ parser.add_option(
156
+ '-d',
157
+ '--debug',
158
+ dest='DEBUG',
159
+ action='store_true',
160
+ help='run in debugging mode (insecure)',
161
+ )
162
+ parser.add_option(
163
+ '-e',
164
+ '--overlap',
165
+ metavar='PIXELS',
166
+ dest='DEEPZOOM_OVERLAP',
167
+ type='int',
168
+ help='overlap of adjacent tiles [1]',
169
+ )
170
+ parser.add_option(
171
+ '-f',
172
+ '--format',
173
+ metavar='{jpeg|png}',
174
+ dest='DEEPZOOM_FORMAT',
175
+ help='image format for tiles [jpeg]',
176
+ )
177
+ parser.add_option(
178
+ '-l',
179
+ '--listen',
180
+ metavar='ADDRESS',
181
+ dest='host',
182
+ default='127.0.0.1',
183
+ help='address to listen on [127.0.0.1]',
184
+ )
185
+ parser.add_option(
186
+ '-p',
187
+ '--port',
188
+ metavar='PORT',
189
+ dest='port',
190
+ type='int',
191
+ default=5000,
192
+ help='port to listen on [5000]',
193
+ )
194
+ parser.add_option(
195
+ '-Q',
196
+ '--quality',
197
+ metavar='QUALITY',
198
+ dest='DEEPZOOM_TILE_QUALITY',
199
+ type='int',
200
+ help='JPEG compression quality [75]',
201
+ )
202
+ parser.add_option(
203
+ '-s',
204
+ '--size',
205
+ metavar='PIXELS',
206
+ dest='DEEPZOOM_TILE_SIZE',
207
+ type='int',
208
+ help='tile size [254]',
209
+ )
210
+
211
+ (opts, args) = parser.parse_args()
212
+ # Load config file if specified
213
+ if opts.config is not None:
214
+ app.config.from_pyfile(opts.config)
215
+ # Overwrite only those settings specified on the command line
216
+ for k in dir(opts):
217
+ if not k.startswith('_') and getattr(opts, k) is None:
218
+ delattr(opts, k)
219
+ app.config.from_object(opts)
220
+ # Set slide file
221
+ try:
222
+ app.config['DEEPZOOM_SLIDE'] = args[0]
223
+ except IndexError:
224
+ if app.config['DEEPZOOM_SLIDE'] is None:
225
+ parser.error('No slide file specified')
226
+
227
+ app.run(host=opts.host, port=opts.port, threaded=True)
helper.py ADDED
@@ -0,0 +1,104 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ from __future__ import absolute_import, division, print_function
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.autograd import Variable
12
+ from torchvision import transforms
13
+ from utils.metrics import ConfusionMatrix
14
+ from PIL import Image
15
+ import os
16
+
17
+ # torch.cuda.synchronize()
18
+ # torch.backends.cudnn.benchmark = True
19
+ torch.backends.cudnn.deterministic = True
20
+
21
+ def collate(batch):
22
+ image = [ b['image'] for b in batch ] # w, h
23
+ label = [ b['label'] for b in batch ]
24
+ id = [ b['id'] for b in batch ]
25
+ adj_s = [ b['adj_s'] for b in batch ]
26
+ return {'image': image, 'label': label, 'id': id, 'adj_s': adj_s}
27
+
28
+ def preparefeatureLabel(batch_graph, batch_label, batch_adjs, device='cpu'):
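+ # Pad node features and adjacency matrices to the largest graph in the batch and build
+ # node masks so that variable-sized graphs can be stacked into dense tensors.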
29
+ batch_size = len(batch_graph)
30
+ labels = torch.LongTensor(batch_size)
31
+ max_node_num = 0
32
+
33
+ for i in range(batch_size):
34
+ labels[i] = batch_label[i]
35
+ max_node_num = max(max_node_num, batch_graph[i].shape[0])
36
+
37
+ masks = torch.zeros(batch_size, max_node_num)
38
+ adjs = torch.zeros(batch_size, max_node_num, max_node_num)
39
+ batch_node_feat = torch.zeros(batch_size, max_node_num, 512)
40
+
41
+ for i in range(batch_size):
42
+ cur_node_num = batch_graph[i].shape[0]
43
+ #node attribute feature
44
+ tmp_node_fea = batch_graph[i]
45
+ batch_node_feat[i, 0:cur_node_num] = tmp_node_fea
46
+
47
+ #adjs
48
+ adjs[i, 0:cur_node_num, 0:cur_node_num] = batch_adjs[i]
49
+
50
+ #masks
51
+ masks[i,0:cur_node_num] = 1
52
+
53
+ node_feat = batch_node_feat.to(device)
54
+ labels = labels.to(device)
55
+ adjs = adjs.to(device)
56
+ masks = masks.to(device)
57
+
58
+ return node_feat, labels, adjs, masks
59
+
60
+ class Trainer(object):
61
+ def __init__(self, n_class):
62
+ self.metrics = ConfusionMatrix(n_class)
63
+
64
+ def get_scores(self):
65
+ acc = self.metrics.get_scores()
66
+
67
+ return acc
68
+
69
+ def reset_metrics(self):
70
+ self.metrics.reset()
71
+
72
+ def plot_cm(self):
73
+ self.metrics.plotcm()
74
+
75
+ def train(self, sample, model):
76
+ node_feat, labels, adjs, masks = preparefeatureLabel(sample['image'], sample['label'], sample['adj_s'])
77
+ pred,labels,loss = model.forward(node_feat, labels, adjs, masks)
78
+
79
+ return pred,labels,loss
80
+
81
+ class Evaluator(object):
82
+ def __init__(self, n_class):
83
+ self.metrics = ConfusionMatrix(n_class)
84
+
85
+ def get_scores(self):
86
+ acc = self.metrics.get_scores()
87
+
88
+ return acc
89
+
90
+ def reset_metrics(self):
91
+ self.metrics.reset()
92
+
93
+ def plot_cm(self):
94
+ self.metrics.plotcm()
95
+
96
+ def eval_test(self, sample, model, graphcam_flag=False):
97
+ node_feat, labels, adjs, masks = preparefeatureLabel(sample['image'], sample['label'], sample['adj_s'])
98
+ if not graphcam_flag:
99
+ with torch.no_grad():
100
+ pred,labels,loss = model.forward(node_feat, labels, adjs, masks)
101
+ else:
102
+ torch.set_grad_enabled(True)
103
+ pred,labels,loss= model.forward(node_feat, labels, adjs, masks, graphcam_flag=graphcam_flag)
104
+ return pred,labels,loss
main.py ADDED
@@ -0,0 +1,169 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ from __future__ import absolute_import, division, print_function
5
+
6
+ import os
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from torchvision import transforms
11
+
12
+ from utils.dataset import GraphDataset
13
+ from utils.lr_scheduler import LR_Scheduler
14
+ from tensorboardX import SummaryWriter
15
+ from helper import Trainer, Evaluator, collate
16
+ from option import Options
17
+
18
+ from models.GraphTransformer import Classifier
19
+ from models.weight_init import weight_init
20
+ import pickle
21
+ args = Options().parse()
22
+
23
+ label_map = pickle.load(open(os.path.join(args.dataset_metadata_path, 'label_map.pkl'), 'rb'))
24
+
25
+ n_class = len(label_map)
26
+
27
+ torch.cuda.synchronize()
28
+ torch.backends.cudnn.deterministic = True
29
+
30
+ data_path = args.data_path
31
+ model_path = args.model_path
32
+ if not os.path.isdir(model_path): os.mkdir(model_path)
33
+ log_path = args.log_path
34
+ if not os.path.isdir(log_path): os.mkdir(log_path)
35
+ task_name = args.task_name
36
+
37
+ print(task_name)
38
+ ###################################
39
+ train = args.train
40
+ test = args.test
41
+ graphcam = args.graphcam
42
+ print("train:", train, "test:", test, "graphcam:", graphcam)
43
+
44
+ ##### Load datasets
45
+ print("preparing datasets and dataloaders......")
46
+ batch_size = args.batch_size
47
+
48
+ if train:
49
+ ids_train = open(args.train_set).readlines()
50
+ dataset_train = GraphDataset(os.path.join(data_path, ""), ids_train, args.dataset_metadata_path)
51
+ dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=True, pin_memory=True, drop_last=True)
52
+ total_train_num = len(dataloader_train) * batch_size
53
+
54
+ ids_val = open(args.val_set).readlines()
55
+ dataset_val = GraphDataset(os.path.join(data_path, ""), ids_val, args.dataset_metadata_path)
56
+ dataloader_val = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=False, pin_memory=True)
57
+ total_val_num = len(dataloader_val) * batch_size
58
+
59
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
+ ##### creating models #############
61
+ print("creating models......")
62
+
63
+ num_epochs = args.num_epochs
64
+ learning_rate = args.lr
65
+
66
+ model = Classifier(n_class)
67
+ model = nn.DataParallel(model)
68
+ if args.resume:
69
+ print('load model {}'.format(args.resume))
70
+ model.load_state_dict(torch.load(args.resume))
71
+
72
+ if torch.cuda.is_available():
73
+ model = model.cuda()
74
+ #model.apply(weight_init)
75
+
76
+ optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, weight_decay = 5e-4) # best:5e-4, 4e-3
77
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,100], gamma=0.1) # gamma=0.3 # 30,90,130 # 20,90,130 -> 150
78
+
79
+ ##################################
80
+
81
+ criterion = nn.CrossEntropyLoss()
82
+
83
+ if not test:
84
+ writer = SummaryWriter(log_dir=log_path + task_name)
85
+ f_log = open(log_path + task_name + ".log", 'w')
86
+
87
+ trainer = Trainer(n_class)
88
+ evaluator = Evaluator(n_class)
89
+
90
+ best_pred = 0.0
91
+ for epoch in range(num_epochs):
92
+ # optimizer.zero_grad()
93
+ model.train()
94
+ train_loss = 0.
95
+ total = 0.
96
+
97
+ current_lr = optimizer.param_groups[0]['lr']
98
+ print('\n=>Epoches %i, learning rate = %.7f, previous best = %.4f' % (epoch+1, current_lr, best_pred))
99
+
100
+ if train:
101
+ for i_batch, sample_batched in enumerate(dataloader_train):
102
+ scheduler.step(epoch)
103
+
104
+ preds,labels,loss = trainer.train(sample_batched, model)
105
+
106
+ optimizer.zero_grad()
107
+ loss.backward()
108
+ optimizer.step()
109
+
110
+ train_loss += loss
111
+ total += len(labels)
112
+
113
+ trainer.metrics.update(labels, preds)
114
+ if (i_batch + 1) % args.log_interval_local == 0:
115
+ print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total, total_train_num, train_loss / total, trainer.get_scores()))
116
+ trainer.plot_cm()
117
+
118
+ if not test:
119
+ print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total_train_num, total_train_num, train_loss / total, trainer.get_scores()))
120
+ trainer.plot_cm()
121
+
122
+
123
+ if epoch % 1 == 0:
124
+ with torch.no_grad():
125
+ model.eval()
126
+ print("evaluating...")
127
+
128
+ total = 0.
129
+ batch_idx = 0
130
+
131
+ for i_batch, sample_batched in enumerate(dataloader_val):
132
+ preds, labels, _ = evaluator.eval_test(sample_batched, model, graphcam)
133
+
134
+ total += len(labels)
135
+
136
+ evaluator.metrics.update(labels, preds)
137
+
138
+ if (i_batch + 1) % args.log_interval_local == 0:
139
+ print('[%d/%d] val agg acc: %.3f' % (total, total_val_num, evaluator.get_scores()))
140
+ evaluator.plot_cm()
141
+
142
+ print('[%d/%d] val agg acc: %.3f' % (total_val_num, total_val_num, evaluator.get_scores()))
143
+ evaluator.plot_cm()
144
+
145
+ # torch.cuda.empty_cache()
146
+
147
+ val_acc = evaluator.get_scores()
148
+ if val_acc > best_pred:
149
+ best_pred = val_acc
150
+ if not test:
151
+ print("saving model...")
152
+ torch.save(model.state_dict(), model_path + task_name + ".pth")
153
+
154
+ log = ""
155
+ log = log + 'epoch [{}/{}] ------ acc: train = {:.4f}, val = {:.4f}'.format(epoch+1, num_epochs, trainer.get_scores(), evaluator.get_scores()) + "\n"
156
+
157
+ log += "================================\n"
158
+ print(log)
159
+ if test: break
160
+
161
+ f_log.write(log)
162
+ f_log.flush()
163
+
164
+ writer.add_scalars('accuracy', {'train acc': trainer.get_scores(), 'val acc': evaluator.get_scores()}, epoch+1)
165
+
166
+ trainer.reset_metrics()
167
+ evaluator.reset_metrics()
168
+
169
+ if not test: f_log.close()
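A quick illustrative sketch (not part of the commit) of what the Adam + MultiStepLR settings above imply: with lr=1e-3, milestones=[20, 100] and gamma=0.1, the learning rate drops to 1e-4 after epoch 20 and to 1e-5 after epoch 100.

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 100], gamma=0.1)
    for epoch in range(120):
        optimizer.step()                 # the real step happens per batch in the loop above
        scheduler.step()                 # stepped once per epoch in this sketch
        if epoch in (19, 20, 99, 100):
            print(epoch + 1, optimizer.param_groups[0]['lr'])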
metadata/label_map.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce5be416a8667c9379502eaf8407e6d07bbae03749085190be630bd3b026eb52
3
+ size 34
models/.gitkeep ADDED
@@ -0,0 +1 @@
1
+
models/GraphTransformer.py ADDED
@@ -0,0 +1,123 @@
1
+ import sys
2
+ import os
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+
7
+ from torch.autograd import Variable
8
+ from torch.nn.parameter import Parameter
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.optim as optim
12
+
13
+ from .ViT import *
14
+ from .gcn import GCNBlock
15
+
16
+ from torch_geometric.nn import GCNConv, DenseGraphConv, dense_mincut_pool
17
+ from torch.nn import Linear
18
+ class Classifier(nn.Module):
19
+ def __init__(self, n_class):
20
+ super(Classifier, self).__init__()
21
+
22
+ self.n_class = n_class
23
+ self.embed_dim = 64
24
+ self.num_layers = 3
25
+ self.node_cluster_num = 100
26
+
27
+ self.transformer = VisionTransformer(num_classes=n_class, embed_dim=self.embed_dim)
28
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
29
+ self.criterion = nn.CrossEntropyLoss()
30
+
31
+ self.bn = 1
32
+ self.add_self = 1
33
+ self.normalize_embedding = 1
34
+ self.conv1 = GCNBlock(512,self.embed_dim,self.bn,self.add_self,self.normalize_embedding,0.,0) # 64->128
35
+ self.pool1 = Linear(self.embed_dim, self.node_cluster_num) # 100-> 20
36
+
37
+
38
+ def forward(self,node_feat,labels,adj,mask,is_print=False, graphcam_flag=False, to_file=True):
39
+ # node_feat, labels = self.PrepareFeatureLabel(batch_graph)
40
+ cls_loss=node_feat.new_zeros(self.num_layers)
41
+ rank_loss=node_feat.new_zeros(self.num_layers-1)
42
+ X=node_feat
43
+ p_t=[]
44
+ pred_logits=0
45
+ visualize_tools=[]
46
+ if labels is not None:
47
+ visualize_tools1=[labels.cpu()]
48
+ embeds=0
49
+ concats=[]
50
+
51
+ layer_acc=[]
52
+
53
+ X=mask.unsqueeze(2)*X
54
+ X = self.conv1(X, adj, mask)
55
+ s = self.pool1(X)
56
+
57
+
58
+ graphcam_tensors = {}
59
+
60
+ if graphcam_flag:
61
+ s_matrix = torch.argmax(s[0], dim=1)
62
+ if to_file:
63
+ from os import path
64
+ os.makedirs('graphcam', exist_ok=True)
65
+ torch.save(s_matrix, 'graphcam/s_matrix.pt')
66
+ torch.save(s[0], 'graphcam/s_matrix_ori.pt')
67
+
68
+ if path.exists('graphcam/att_1.pt'):
69
+ os.remove('graphcam/att_1.pt')
70
+ os.remove('graphcam/att_2.pt')
71
+ os.remove('graphcam/att_3.pt')
72
+
73
+ if not to_file:
74
+ graphcam_tensors['s_matrix'] = s_matrix
75
+ graphcam_tensors['s_matrix_ori'] = s[0]
76
+
77
+
78
+ X, adj, mc1, o1 = dense_mincut_pool(X, adj, s, mask)
79
+ b, _, _ = X.shape
80
+ cls_token = self.cls_token.repeat(b, 1, 1)
81
+ X = torch.cat([cls_token, X], dim=1)
82
+
83
+ out = self.transformer(X)
84
+
85
+ loss = None
86
+ if labels is not None:
87
+ # loss
88
+ loss = self.criterion(out, labels)
89
+ loss = loss + mc1 + o1
90
+ # pred
91
+ pred = out.data.max(1)[1]
92
+
93
+ if graphcam_flag:
94
+ #print('GraphCAM enabled')
95
+ #print(out.shape)
96
+ p = F.softmax(out, dim=1)
97
+ #print(p.shape)
98
+ if to_file:
99
+ torch.save(p, 'graphcam/prob.pt')
100
+ if not to_file:
101
+ graphcam_tensors['prob'] = p
102
+ index = np.argmax(out.cpu().data.numpy(), axis=-1)
103
+
104
+ for index_ in range(self.n_class):
105
+ one_hot = np.zeros((1, out.size()[-1]), dtype=np.float32)
106
+ one_hot[0, index_] = out[0][index_]
107
+ one_hot_vector = one_hot
108
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
109
+ one_hot = torch.sum(one_hot.to( 'cuda' if torch.cuda.is_available() else 'cpu') * out) #!!!!!!!!!!!!!!!!!!!!out-->p
110
+ self.transformer.zero_grad()
111
+ one_hot.backward(retain_graph=True)
112
+
113
+ kwargs = {"alpha": 1}
114
+ cam = self.transformer.relprop(torch.tensor(one_hot_vector).to(X.device), method="transformer_attribution", is_ablation=False,
115
+ start_layer=0, **kwargs)
116
+ if to_file:
117
+ torch.save(cam, 'graphcam/cam_{}.pt'.format(index_))
118
+ if not to_file:
119
+ graphcam_tensors[f'cam_{index_}'] = cam
120
+
121
+ if not to_file:
122
+ return pred,labels,loss, graphcam_tensors
123
+ return pred,labels,loss
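For orientation, a hedged usage sketch of the classifier's forward interface (not part of the commit; requires torch_geometric, and all shapes below are illustrative): it takes 512-dimensional patch-node features, a dense adjacency matrix, a node mask and the labels.

    import torch
    from models.GraphTransformer import Classifier

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Classifier(n_class=3).to(device)

    node_feat = torch.randn(1, 50, 512, device=device)  # [batch, nodes, feature_dim]
    adj = torch.rand(1, 50, 50, device=device)          # dense adjacency matrix
    mask = torch.ones(1, 50, device=device)             # all 50 node slots are valid
    labels = torch.tensor([1], device=device)

    pred, labels, loss = model(node_feat, labels, adj, mask)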
models/ViT.py ADDED
@@ -0,0 +1,415 @@
1
+ """ Vision Transformer (ViT) in PyTorch
2
+ """
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from .layers import *
7
+ import math, warnings  # warnings is used by the bounds check in _no_grad_trunc_normal_
8
+
9
+
10
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
11
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
12
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
13
+ def norm_cdf(x):
14
+ # Computes standard normal cumulative distribution function
15
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
16
+
17
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
18
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
19
+ "The distribution of values may be incorrect.",
20
+ stacklevel=2)
21
+
22
+ with torch.no_grad():
23
+ # Values are generated by using a truncated uniform distribution and
24
+ # then using the inverse CDF for the normal distribution.
25
+ # Get upper and lower cdf values
26
+ l = norm_cdf((a - mean) / std)
27
+ u = norm_cdf((b - mean) / std)
28
+
29
+ # Uniformly fill tensor with values from [l, u], then translate to
30
+ # [2l-1, 2u-1].
31
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
32
+
33
+ # Use inverse cdf transform for normal distribution to get truncated
34
+ # standard normal
35
+ tensor.erfinv_()
36
+
37
+ # Transform to proper mean, std
38
+ tensor.mul_(std * math.sqrt(2.))
39
+ tensor.add_(mean)
40
+
41
+ # Clamp to ensure it's in the proper range
42
+ tensor.clamp_(min=a, max=b)
43
+ return tensor
44
+
45
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
46
+ # type: (Tensor, float, float, float, float) -> Tensor
47
+ r"""Fills the input Tensor with values drawn from a truncated
48
+ normal distribution. The values are effectively drawn from the
49
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
50
+ with values outside :math:`[a, b]` redrawn until they are within
51
+ the bounds. The method used for generating the random values works
52
+ best when :math:`a \leq \text{mean} \leq b`.
53
+ Args:
54
+ tensor: an n-dimensional `torch.Tensor`
55
+ mean: the mean of the normal distribution
56
+ std: the standard deviation of the normal distribution
57
+ a: the minimum cutoff value
58
+ b: the maximum cutoff value
59
+ Examples:
60
+ >>> w = torch.empty(3, 5)
61
+ >>> nn.init.trunc_normal_(w)
62
+ """
63
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
64
+
65
+ def _cfg(url='', **kwargs):
66
+ return {
67
+ 'url': url,
68
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
69
+ 'crop_pct': .9, 'interpolation': 'bicubic',
70
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
71
+ **kwargs
72
+ }
73
+
74
+
75
+ default_cfgs = {
76
+ # patch models
77
+ 'vit_small_patch16_224': _cfg(
78
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
79
+ ),
80
+ 'vit_base_patch16_224': _cfg(
81
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
82
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
83
+ ),
84
+ 'vit_large_patch16_224': _cfg(
85
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
86
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
87
+ }
88
+
89
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
90
+ # adding residual consideration
91
+ num_tokens = all_layer_matrices[0].shape[1]
92
+ batch_size = all_layer_matrices[0].shape[0]
93
+ eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
94
+ all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
95
+ # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
96
+ # for i in range(len(all_layer_matrices))]
97
+ joint_attention = all_layer_matrices[start_layer]
98
+ for i in range(start_layer+1, len(all_layer_matrices)):
99
+ joint_attention = all_layer_matrices[i].bmm(joint_attention)
100
+ return joint_attention
101
+
102
+ class Mlp(nn.Module):
103
+ def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
104
+ super().__init__()
105
+ out_features = out_features or in_features
106
+ hidden_features = hidden_features or in_features
107
+ self.fc1 = Linear(in_features, hidden_features)
108
+ self.act = GELU()
109
+ self.fc2 = Linear(hidden_features, out_features)
110
+ self.drop = Dropout(drop)
111
+
112
+ def forward(self, x):
113
+ x = self.fc1(x)
114
+ x = self.act(x)
115
+ x = self.drop(x)
116
+ x = self.fc2(x)
117
+ x = self.drop(x)
118
+ return x
119
+
120
+ def relprop(self, cam, **kwargs):
121
+ cam = self.drop.relprop(cam, **kwargs)
122
+ cam = self.fc2.relprop(cam, **kwargs)
123
+ cam = self.act.relprop(cam, **kwargs)
124
+ cam = self.fc1.relprop(cam, **kwargs)
125
+ return cam
126
+
127
+
128
+ class Attention(nn.Module):
129
+ def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
130
+ super().__init__()
131
+ self.num_heads = num_heads
132
+ head_dim = dim // num_heads
133
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
134
+ self.scale = head_dim ** -0.5
135
+
136
+ # A = Q*K^T
137
+ self.matmul1 = einsum('bhid,bhjd->bhij')
138
+ # attn = A*V
139
+ self.matmul2 = einsum('bhij,bhjd->bhid')
140
+
141
+ self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
142
+ self.attn_drop = Dropout(attn_drop)
143
+ self.proj = Linear(dim, dim)
144
+ self.proj_drop = Dropout(proj_drop)
145
+ self.softmax = Softmax(dim=-1)
146
+
147
+ self.attn_cam = None
148
+ self.attn = None
149
+ self.v = None
150
+ self.v_cam = None
151
+ self.attn_gradients = None
152
+
153
+ def get_attn(self):
154
+ return self.attn
155
+
156
+ def save_attn(self, attn):
157
+ self.attn = attn
158
+
159
+ def save_attn_cam(self, cam):
160
+ self.attn_cam = cam
161
+
162
+ def get_attn_cam(self):
163
+ return self.attn_cam
164
+
165
+ def get_v(self):
166
+ return self.v
167
+
168
+ def save_v(self, v):
169
+ self.v = v
170
+
171
+ def save_v_cam(self, cam):
172
+ self.v_cam = cam
173
+
174
+ def get_v_cam(self):
175
+ return self.v_cam
176
+
177
+ def save_attn_gradients(self, attn_gradients):
178
+ self.attn_gradients = attn_gradients
179
+
180
+ def get_attn_gradients(self):
181
+ return self.attn_gradients
182
+
183
+ def forward(self, x):
184
+ b, n, _, h = *x.shape, self.num_heads
185
+ qkv = self.qkv(x)
186
+ q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
187
+
188
+ self.save_v(v)
189
+
190
+ dots = self.matmul1([q, k]) * self.scale
191
+
192
+ attn = self.softmax(dots)
193
+ attn = self.attn_drop(attn)
194
+
195
+ # Get attention
196
+ if False:
197
+ from os import path
198
+ if not path.exists('att_1.pt'):
199
+ torch.save(attn, 'att_1.pt')
200
+ elif not path.exists('att_2.pt'):
201
+ torch.save(attn, 'att_2.pt')
202
+ else:
203
+ torch.save(attn, 'att_3.pt')
204
+
205
+ #comment in training
206
+ if x.requires_grad:
207
+ self.save_attn(attn)
208
+ attn.register_hook(self.save_attn_gradients)
209
+
210
+ out = self.matmul2([attn, v])
211
+ out = rearrange(out, 'b h n d -> b n (h d)')
212
+
213
+ out = self.proj(out)
214
+ out = self.proj_drop(out)
215
+ return out
216
+
217
+ def relprop(self, cam, **kwargs):
218
+ cam = self.proj_drop.relprop(cam, **kwargs)
219
+ cam = self.proj.relprop(cam, **kwargs)
220
+ cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads)
221
+
222
+ # attn = A*V
223
+ (cam1, cam_v)= self.matmul2.relprop(cam, **kwargs)
224
+ cam1 /= 2
225
+ cam_v /= 2
226
+
227
+ self.save_v_cam(cam_v)
228
+ self.save_attn_cam(cam1)
229
+
230
+ cam1 = self.attn_drop.relprop(cam1, **kwargs)
231
+ cam1 = self.softmax.relprop(cam1, **kwargs)
232
+
233
+ # A = Q*K^T
234
+ (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
235
+ cam_q /= 2
236
+ cam_k /= 2
237
+
238
+ cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads)
239
+
240
+ return self.qkv.relprop(cam_qkv, **kwargs)
241
+
242
+
243
+ class Block(nn.Module):
244
+
245
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
246
+ super().__init__()
247
+ self.norm1 = LayerNorm(dim, eps=1e-6)
248
+ self.attn = Attention(
249
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
250
+ self.norm2 = LayerNorm(dim, eps=1e-6)
251
+ mlp_hidden_dim = int(dim * mlp_ratio)
252
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
253
+
254
+ self.add1 = Add()
255
+ self.add2 = Add()
256
+ self.clone1 = Clone()
257
+ self.clone2 = Clone()
258
+
259
+ def forward(self, x):
260
+ x1, x2 = self.clone1(x, 2)
261
+ x = self.add1([x1, self.attn(self.norm1(x2))])
262
+ x1, x2 = self.clone2(x, 2)
263
+ x = self.add2([x1, self.mlp(self.norm2(x2))])
264
+ return x
265
+
266
+ def relprop(self, cam, **kwargs):
267
+ (cam1, cam2) = self.add2.relprop(cam, **kwargs)
268
+ cam2 = self.mlp.relprop(cam2, **kwargs)
269
+ cam2 = self.norm2.relprop(cam2, **kwargs)
270
+ cam = self.clone2.relprop((cam1, cam2), **kwargs)
271
+
272
+ (cam1, cam2) = self.add1.relprop(cam, **kwargs)
273
+ cam2 = self.attn.relprop(cam2, **kwargs)
274
+ cam2 = self.norm1.relprop(cam2, **kwargs)
275
+ cam = self.clone1.relprop((cam1, cam2), **kwargs)
276
+ return cam
277
+
278
+ class VisionTransformer(nn.Module):
279
+ """ Vision Transformer with support for patch or hybrid CNN input stage
280
+ """
281
+ def __init__(self, num_classes=2, embed_dim=64, depth=3,
282
+ num_heads=8, mlp_ratio=2., qkv_bias=False, mlp_head=False, drop_rate=0., attn_drop_rate=0.):
283
+ super().__init__()
284
+ self.num_classes = num_classes
285
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
286
+
287
+ self.blocks = nn.ModuleList([
288
+ Block(
289
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
290
+ drop=drop_rate, attn_drop=attn_drop_rate)
291
+ for i in range(depth)])
292
+
293
+ self.norm = LayerNorm(embed_dim)
294
+ if mlp_head:
295
+ # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
296
+ self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
297
+ else:
298
+ # with a single Linear layer as head, the param count within rounding of paper
299
+ self.head = Linear(embed_dim, num_classes)
300
+
301
+ #self.apply(self._init_weights)
302
+
303
+ self.pool = IndexSelect()
304
+ self.add = Add()
305
+
306
+ self.inp_grad = None
307
+
308
+ def save_inp_grad(self,grad):
309
+ self.inp_grad = grad
310
+
311
+ def get_inp_grad(self):
312
+ return self.inp_grad
313
+
314
+
315
+ def _init_weights(self, m):
316
+ if isinstance(m, nn.Linear):
317
+ trunc_normal_(m.weight, std=.02)
318
+ if isinstance(m, nn.Linear) and m.bias is not None:
319
+ nn.init.constant_(m.bias, 0)
320
+ elif isinstance(m, nn.LayerNorm):
321
+ nn.init.constant_(m.bias, 0)
322
+ nn.init.constant_(m.weight, 1.0)
323
+
324
+ @property
325
+ def no_weight_decay(self):
326
+ return {'pos_embed', 'cls_token'}
327
+
328
+ def forward(self, x):
329
+ if x.requires_grad:
330
+ x.register_hook(self.save_inp_grad) #comment it in train
331
+
332
+ for blk in self.blocks:
333
+ x = blk(x)
334
+
335
+ x = self.norm(x)
336
+ x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
337
+ x = x.squeeze(1)
338
+ x = self.head(x)
339
+ return x
340
+
341
+ def relprop(self, cam=None,method="transformer_attribution", is_ablation=False, start_layer=0, **kwargs):
342
+ # print(kwargs)
343
+ # print("conservation 1", cam.sum())
344
+ cam = self.head.relprop(cam, **kwargs)
345
+ cam = cam.unsqueeze(1)
346
+ cam = self.pool.relprop(cam, **kwargs)
347
+ cam = self.norm.relprop(cam, **kwargs)
348
+ for blk in reversed(self.blocks):
349
+ cam = blk.relprop(cam, **kwargs)
350
+
351
+ # print("conservation 2", cam.sum())
352
+ # print("min", cam.min())
353
+
354
+ if method == "full":
355
+ (cam, _) = self.add.relprop(cam, **kwargs)
356
+ cam = cam[:, 1:]
357
+ cam = self.patch_embed.relprop(cam, **kwargs)
358
+ # sum on channels
359
+ cam = cam.sum(dim=1)
360
+ return cam
361
+
362
+ elif method == "rollout":
363
+ # cam rollout
364
+ attn_cams = []
365
+ for blk in self.blocks:
366
+ attn_heads = blk.attn.get_attn_cam().clamp(min=0)
367
+ avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
368
+ attn_cams.append(avg_heads)
369
+ cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
370
+ cam = cam[:, 0, 1:]
371
+ return cam
372
+
373
+ # our method, method name grad is legacy
374
+ elif method == "transformer_attribution" or method == "grad":
375
+ cams = []
376
+ for blk in self.blocks:
377
+ grad = blk.attn.get_attn_gradients()
378
+ cam = blk.attn.get_attn_cam()
379
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
380
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
381
+ cam = grad * cam
382
+ cam = cam.clamp(min=0).mean(dim=0)
383
+ cams.append(cam.unsqueeze(0))
384
+ rollout = compute_rollout_attention(cams, start_layer=start_layer)
385
+ cam = rollout[:, 0, 1:]
386
+ return cam
387
+
388
+ elif method == "last_layer":
389
+ cam = self.blocks[-1].attn.get_attn_cam()
390
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
391
+ if is_ablation:
392
+ grad = self.blocks[-1].attn.get_attn_gradients()
393
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
394
+ cam = grad * cam
395
+ cam = cam.clamp(min=0).mean(dim=0)
396
+ cam = cam[0, 1:]
397
+ return cam
398
+
399
+ elif method == "last_layer_attn":
400
+ cam = self.blocks[-1].attn.get_attn()
401
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
402
+ cam = cam.clamp(min=0).mean(dim=0)
403
+ cam = cam[0, 1:]
404
+ return cam
405
+
406
+ elif method == "second_layer":
407
+ cam = self.blocks[1].attn.get_attn_cam()
408
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
409
+ if is_ablation:
410
+ grad = self.blocks[1].attn.get_attn_gradients()
411
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
412
+ cam = grad * cam
413
+ cam = cam.clamp(min=0).mean(dim=0)
414
+ cam = cam[0, 1:]
415
+ return cam
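A small illustrative call (an assumption about typical use, not from the commit): the transformer here consumes an already-embedded token sequence, i.e. a cls token followed by the pooled cluster tokens of width embed_dim.

    import torch
    from models.ViT import VisionTransformer

    vit = VisionTransformer(num_classes=3, embed_dim=64, depth=3)
    tokens = torch.randn(2, 101, 64)   # [batch, 1 cls token + 100 cluster tokens, embed_dim]
    logits = vit(tokens)               # -> [2, 3]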
models/__init__.py ADDED
File without changes
models/__pycache__/GraphTransformer.cpython-38.pyc ADDED
Binary file (3.35 kB). View file
 
models/__pycache__/ViT.cpython-38.pyc ADDED
Binary file (12.5 kB). View file
 
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (168 Bytes). View file
 
models/__pycache__/gcn.cpython-38.pyc ADDED
Binary file (9.61 kB). View file
 
models/__pycache__/layers.cpython-38.pyc ADDED
Binary file (9.93 kB). View file
 
models/__pycache__/weight_init.cpython-38.pyc ADDED
Binary file (1.72 kB). View file
 
models/gcn.py ADDED
@@ -0,0 +1,420 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import torch.nn.functional as F
5
+ import math
6
+
7
+ import numpy as np
8
+
9
+ torch.set_printoptions(precision=2,threshold=float('inf'))
10
+
11
+ class AGCNBlock(nn.Module):
12
+ def __init__(self,input_dim,hidden_dim,gcn_layer=2,dropout=0.0,relu=0):
13
+ super(AGCNBlock,self).__init__()
14
+ if dropout > 0.001:
15
+ self.dropout_layer = nn.Dropout(p=dropout)
16
+ self.sort = 'sort'
17
+ self.model='agcn'
18
+ self.gcns=nn.ModuleList()
19
+ self.bn = 0
20
+ self.add_self = 1
21
+ self.normalize_embedding = 1
22
+ self.gcns.append(GCNBlock(input_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu))
23
+ self.pool = 'mean'
24
+ self.tau = 1.
25
+ self.lamda = 1.
26
+
27
+ for i in range(gcn_layer-1):
28
+ if i==gcn_layer-2 and (not 1):
29
+ self.gcns.append(GCNBlock(hidden_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,0))
30
+ else:
31
+ self.gcns.append(GCNBlock(hidden_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu))
32
+
33
+ if self.model=='diffpool':
34
+ self.pool_gcns=nn.ModuleList()
35
+ tmp=input_dim
36
+ self.diffpool_k=200
37
+ for i in range(3):
38
+ self.pool_gcns.append(GCNBlock(tmp,200,0,0,0,dropout,relu))
39
+ tmp=200
40
+
41
+ self.w_a=nn.Parameter(torch.zeros(1,hidden_dim,1))
42
+ self.w_b=nn.Parameter(torch.zeros(1,hidden_dim,1))
43
+ torch.nn.init.normal_(self.w_a)
44
+ torch.nn.init.uniform_(self.w_b,-1,1)
45
+
46
+ self.pass_dim=hidden_dim
47
+
48
+ if self.pool=='mean':
49
+ self.pool=self.mean_pool
50
+ elif self.pool=='max':
51
+ self.pool=self.max_pool
52
+ elif self.pool=='sum':
53
+ self.pool=self.sum_pool
54
+
55
+ self.softmax='global'
56
+ if self.softmax=='gcn':
57
+ self.att_gcn=GCNBlock(2,1,0,0,dropout,relu)
58
+ self.khop=1
59
+ self.adj_norm='none'
60
+
61
+ self.filt_percent=0.25 #default 0.5
62
+ self.eps=1e-10
63
+
64
+ self.tau_config=1
65
+ if 1==-1.:
66
+ self.tau=nn.Parameter(torch.tensor(1),requires_grad=False)
67
+ elif 1==-2.:
68
+ self.tau_fc=nn.Linear(hidden_dim,1)
69
+ torch.nn.init.constant_(self.tau_fc.bias,1)
70
+ torch.nn.init.xavier_normal_(self.tau_fc.weight.t())
71
+ else:
72
+ self.tau=nn.Parameter(torch.tensor(self.tau))
73
+ self.lamda1=nn.Parameter(torch.tensor(self.lamda))
74
+ self.lamda2=nn.Parameter(torch.tensor(self.lamda))
75
+
76
+ self.att_norm=0
77
+
78
+ self.dnorm=0
79
+ self.dnorm_coe=1
80
+
81
+ self.att_out=0
82
+ self.single_att=0
83
+
84
+
85
+ def forward(self,X,adj,mask,is_print=False):
86
+ '''
87
+ input:
88
+ X: node input features , [batch,node_num,input_dim],dtype=float
89
+ adj: adj matrix, [batch,node_num,node_num], dtype=float
90
+ mask: mask for nodes, [batch,node_num]
91
+ outputs:
92
+ out:unormalized classification prob, [batch,hidden_dim]
93
+ H: batch of node hidden features, [batch,node_num,pass_dim]
94
+ new_adj: pooled new adj matrix, [batch, k_max, k_max]
95
+ new_mask: [batch, k_max]
96
+ '''
97
+ hidden=X
98
+ #adj = adj.float()
99
+ # print('input size:')
100
+ # print(hidden.shape)
101
+
102
+ is_print1=is_print2=is_print
103
+ if adj.shape[-1]>100:
104
+ is_print1=False
105
+
106
+ for gcn in self.gcns:
107
+ hidden=gcn(hidden,adj,mask)
108
+ # print('gcn:')
109
+ # print(hidden.shape)
110
+ # print('mask:')
111
+ # print(mask.unsqueeze(2).shape)
112
+ # print(mask.sum(dim=1))
113
+
114
+ hidden=mask.unsqueeze(2)*hidden
115
+ # print(hidden[0][0])
116
+ # print(hidden[0][-1])
117
+
118
+ if self.model=='unet':
119
+ att=torch.matmul(hidden,self.w_a).squeeze()
120
+ att=att/torch.sqrt((self.w_a.squeeze(2)**2).sum(dim=1,keepdim=True))
121
+ elif self.model=='agcn':
122
+ if self.softmax=='global' or self.softmax=='mix':
123
+ if False:
124
+ dgree_w = torch.sum(adj, dim=2) / torch.sum(adj, dim=2).max(1, keepdim=True)[0]
125
+ att_a=torch.matmul(hidden,self.w_a).squeeze()*dgree_w+(mask-1)*1e10
126
+ else:
127
+ att_a=torch.matmul(hidden,self.w_a).squeeze()+(mask-1)*1e10
128
+ # print(att_a[0][:10])
129
+ # print(att_a[0][-10:-1])
130
+ att_a_1=att_a=torch.nn.functional.softmax(att_a,dim=1)
131
+ # print(att_a[0][:10])
132
+ # print(att_a[0][-10:-1])
133
+
134
+ if self.dnorm:
135
+ scale=mask.sum(dim=1,keepdim=True)/self.dnorm_coe
136
+ att_a=scale*att_a
137
+ if self.softmax=='neibor' or self.softmax=='mix':
138
+ att_b=torch.matmul(hidden,self.w_b).squeeze()+(mask-1)*1e10
139
+ att_b_max,_=att_b.max(dim=1,keepdim=True)
140
+ if self.tau_config!=-2:
141
+ att_b=torch.exp((att_b-att_b_max)*torch.abs(self.tau))
142
+ else:
143
+ att_b=torch.exp((att_b-att_b_max)*torch.abs(self.tau_fc(self.pool(hidden,mask))))
144
+ denom=att_b.unsqueeze(2)
145
+ for _ in range(self.khop):
146
+ denom=torch.matmul(adj,denom)
147
+ denom=denom.squeeze()+self.eps
148
+ att_b=(att_b*torch.diagonal(adj,0,1,2))/denom
149
+ if self.dnorm:
150
+ if self.adj_norm=='diag':
151
+ diag_scale=mask/(torch.diagonal(adj,0,1,2)+self.eps)
152
+ elif self.adj_norm=='none':
153
+ diag_scale=adj.sum(dim=1)
154
+ att_b=att_b*diag_scale
155
+ att_b=att_b*mask
156
+
157
+ if self.softmax=='global':
158
+ att=att_a
159
+ elif self.softmax=='neibor' or self.softmax=='hardnei':
160
+ att=att_b
161
+ elif self.softmax=='mix':
162
+ att=att_a*torch.abs(self.lamda1)+att_b*torch.abs(self.lamda2)
163
+ # print('att:')
164
+ # print(att.shape)
165
+ Z=hidden
166
+
167
+ if self.model=='unet':
168
+ Z=torch.tanh(att.unsqueeze(2))*Z
169
+ elif self.model=='agcn':
170
+ if self.single_att:
171
+ Z=Z
172
+ else:
173
+ Z=att.unsqueeze(2)*Z
174
+ # print('Z shape')
175
+ # print(Z.shape)
176
+ k_max=int(math.ceil(self.filt_percent*adj.shape[-1]))
177
+ # print('k_max')
178
+ # print(k_max)
179
+ if self.model=='diffpool':
180
+ k_max=min(k_max,self.diffpool_k)
181
+
182
+ k_list=[int(math.ceil(self.filt_percent*x)) for x in mask.sum(dim=1).tolist()]
183
+ # print('k_list')
184
+ # print(k_list)
185
+ if self.model!='diffpool':
186
+ if self.sort=='sample':
187
+ att_samp = att * mask
188
+ att_samp = (att_samp/att_samp.sum(1)).detach().cpu().numpy()
189
+ top_index = ()
190
+ for i in range(att.size(0)):
191
+ top_index = (torch.LongTensor(np.random.choice(att_samp.size(1), k_max, att_samp[i])) ,)
192
+ top_index = torch.stack(top_index,1)
193
+ elif self.sort=='random_sample':
194
+ top_index = torch.LongTensor(att.size(0), k_max)*0
195
+ for i in range(att.size(0)):
196
+ top_index[i,0:k_list[i]] = torch.randperm(int(mask[i].sum().item()))[0:k_list[i]]
197
+ else: #sort
198
+ _,top_index=torch.topk(att,k_max,dim=1)
199
+ # print('top_index')
200
+ # print(top_index)
201
+ # print(len(top_index[0]))
202
+ new_mask=X.new_zeros(X.shape[0],k_max)
203
+ # print('new_mask')
204
+ # print(new_mask.shape)
205
+ visualize_tools=None
206
+ if self.model=='unet':
207
+ for i,k in enumerate(k_list):
208
+ for j in range(int(k),k_max):
209
+ top_index[i][j]=adj.shape[-1]-1
210
+ new_mask[i][j]=-1.
211
+ new_mask=new_mask+1
212
+ top_index,_=torch.sort(top_index,dim=1)
213
+ assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1])
214
+ for i,x in enumerate(top_index):
215
+ assign_m[i]=torch.index_select(adj[i],0,x)
216
+ new_adj=X.new_zeros(X.shape[0],k_max,k_max)
217
+ H=Z.new_zeros(Z.shape[0],k_max,Z.shape[-1])
218
+ for i,x in enumerate(top_index):
219
+ new_adj[i]=torch.index_select(assign_m[i],1,x)
220
+ H[i]=torch.index_select(Z[i],0,x)
221
+
222
+ elif self.model=='agcn':
223
+ assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1])
224
+ # print('assign_m.shape')
225
+ # print(assign_m.shape)
226
+ for i,k in enumerate(k_list):
227
+ #print('top_index[i][j]')
228
+ for j in range(int(k)):
229
+ #print(str(top_index[i][j].item())+' ', end='')
230
+ assign_m[i][j]=adj[i][top_index[i][j]]
231
+ #print(assign_m[i][j])
232
+ new_mask[i][j]=1.
233
+
234
+ assign_m=assign_m/(assign_m.sum(dim=1,keepdim=True)+self.eps)
235
+ H=torch.matmul(assign_m,Z)
236
+ # print('H')
237
+ # print(H.shape)
238
+ new_adj=torch.matmul(torch.matmul(assign_m,adj),torch.transpose(assign_m,1,2))
239
+ # print(torch.matmul(assign_m,adj).shape)
240
+ # print('new_adj:')
241
+ # print(new_adj.shape)
242
+
243
+ elif self.model=='diffpool':
244
+ hidden1=X
245
+ for gcn in self.pool_gcns:
246
+ hidden1=gcn(hidden1,adj,mask)
247
+ assign_m=X.new_ones(X.shape[0],X.shape[1],k_max)*(-100000000.)
248
+ for i,x in enumerate(hidden1):
249
+ k=min(k_list[i],k_max)
250
+ assign_m[i,:,0:k]=hidden1[i,:,0:k]
251
+ for j in range(int(k)):
252
+ new_mask[i][j]=1.
253
+
254
+ assign_m=torch.nn.functional.softmax(assign_m,dim=2)*mask.unsqueeze(2)
255
+ assign_m_t=torch.transpose(assign_m,1,2)
256
+ new_adj=torch.matmul(torch.matmul(assign_m_t,adj),assign_m)
257
+ H=torch.matmul(assign_m_t,Z)
258
+ # print('pool')
259
+ if self.att_out and self.model=='agcn':
260
+ if self.softmax=='global':
261
+ out=self.pool(att_a_1.unsqueeze(2)*hidden,mask)
262
+ elif self.softmax=='neibor':
263
+ att_b_sum=att_b.sum(dim=1,keepdim=True)
264
+ out=self.pool((att_b/(att_b_sum+self.eps)).unsqueeze(2)*hidden,mask)
265
+ else:
266
+ # print('hidden.shape')
267
+ # print(hidden.shape)
268
+ out=self.pool(hidden,mask)
269
+ # print('out shape')
270
+ # print(out.shape)
271
+
272
+ if self.adj_norm=='tanh' or self.adj_norm=='mix':
273
+ new_adj=torch.tanh(new_adj)
274
+ elif self.adj_norm=='diag' or self.adj_norm=='mix':
275
+ diag_elem=torch.pow(new_adj.sum(dim=2)+self.eps,-0.5)
276
+ diag=new_adj.new_zeros(new_adj.shape)
277
+ for i,x in enumerate(diag_elem):
278
+ diag[i]=torch.diagflat(x)
279
+ new_adj=torch.matmul(torch.matmul(diag,new_adj),diag)
280
+
281
+ visualize_tools=[]
282
+ '''
283
+ if (not self.training) and is_print1:
284
+ print('**********************************')
285
+ print('node_feat:',X.type(),X.shape)
286
+ print(X)
287
+ if self.model!='diffpool':
288
+ print('**********************************')
289
+ print('att:',att.type(),att.shape)
290
+ print(att)
291
+ print('**********************************')
292
+ print('top_index:',top_index.type(),top_index.shape)
293
+ print(top_index)
294
+ print('**********************************')
295
+ print('adj:',adj.type(),adj.shape)
296
+ print(adj)
297
+ print('**********************************')
298
+ print('assign_m:',assign_m.type(),assign_m.shape)
299
+ print(assign_m)
300
+ print('**********************************')
301
+ print('new_adj:',new_adj.type(),new_adj.shape)
302
+ print(new_adj)
303
+ print('**********************************')
304
+ print('new_mask:',new_mask.type(),new_mask.shape)
305
+ print(new_mask)
306
+ '''
307
+ #visualization
308
+ from os import path
309
+ if not path.exists('att_1.pt'):
310
+ torch.save(att[0], 'att_1.pt')
311
+ torch.save(top_index[0], 'att_ind1.pt')
312
+ elif not path.exists('att_2.pt'):
313
+ torch.save(att[0], 'att_2.pt')
314
+ torch.save(top_index[0], 'att_ind2.pt')
315
+ else:
316
+ torch.save(att[0], 'att_3.pt')
317
+ torch.save(top_index[0], 'att_ind3.pt')
318
+
319
+ if (not self.training) and is_print2:
320
+ if self.model!='diffpool':
321
+ visualize_tools.append(att[0])
322
+ visualize_tools.append(top_index[0])
323
+ visualize_tools.append(new_adj[0])
324
+ visualize_tools.append(new_mask.sum())
325
+ # print('**********************************')
326
+ return out,H,new_adj,new_mask,visualize_tools
327
+
328
+ def mean_pool(self,x,mask):
329
+ return x.sum(dim=1)/(self.eps+mask.sum(dim=1,keepdim=True))
330
+
331
+ def sum_pool(self,x,mask):
332
+ return x.sum(dim=1)
333
+
334
+ @staticmethod
335
+ def max_pool(x,mask):
336
+ #output: [batch,x.shape[2]]
337
+ m=(mask-1)*1e10
338
+ r,_=(x+m.unsqueeze(2)).max(dim=1)
339
+ return r
340
+ # GCN basic operation
341
+ class GCNBlock(nn.Module):
342
+ def __init__(self, input_dim, output_dim, bn=0,add_self=0, normalize_embedding=0,
343
+ dropout=0.0,relu=0, bias=True):
344
+ super(GCNBlock,self).__init__()
345
+ self.add_self = add_self
346
+ self.dropout = dropout
347
+ self.relu=relu
348
+ self.bn=bn
349
+ if dropout > 0.001:
350
+ self.dropout_layer = nn.Dropout(p=dropout)
351
+ if self.bn:
352
+ self.bn_layer = torch.nn.BatchNorm1d(output_dim)
353
+
354
+ self.normalize_embedding = normalize_embedding
355
+ self.input_dim = input_dim
356
+ self.output_dim = output_dim
357
+
358
+ self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).to( 'cuda' if torch.cuda.is_available() else 'cpu') )
359
+ torch.nn.init.xavier_normal_(self.weight)
360
+ if bias:
361
+ self.bias = nn.Parameter(torch.zeros(output_dim).to( 'cuda' if torch.cuda.is_available() else 'cpu') )
362
+ else:
363
+ self.bias = None
364
+
365
+ def forward(self, x, adj, mask):
366
+ y = torch.matmul(adj, x)
367
+ if self.add_self:
368
+ y += x
369
+ y = torch.matmul(y,self.weight)
370
+ if self.bias is not None:
371
+ y = y + self.bias
372
+ if self.normalize_embedding:
373
+ y = F.normalize(y, p=2, dim=2)
374
+ if self.bn:
375
+ index=mask.sum(dim=1).long().tolist()
376
+ bn_tensor_bf=mask.new_zeros((sum(index),y.shape[2]))
377
+ bn_tensor_af=mask.new_zeros(*y.shape)
378
+ start_index=[]
379
+ ssum=0
380
+ for i in range(x.shape[0]):
381
+ start_index.append(ssum)
382
+ ssum+=index[i]
383
+ start_index.append(ssum)
384
+ for i in range(x.shape[0]):
385
+ bn_tensor_bf[start_index[i]:start_index[i+1]]=y[i,0:index[i]]
386
+ bn_tensor_bf=self.bn_layer(bn_tensor_bf)
387
+ for i in range(x.shape[0]):
388
+ bn_tensor_af[i,0:index[i]]=bn_tensor_bf[start_index[i]:start_index[i+1]]
389
+ y=bn_tensor_af
390
+ if self.dropout > 0.001:
391
+ y = self.dropout_layer(y)
392
+ if self.relu=='relu':
393
+ y=torch.nn.functional.relu(y)
394
+ print('hahah')
395
+ elif self.relu=='lrelu':
396
+ y=torch.nn.functional.leaky_relu(y,0.1)
397
+ return y
398
+
399
+ #experimental function, untested
400
+ class masked_batchnorm(nn.Module):
401
+ def __init__(self,feat_dim,epsilon=1e-10):
402
+ super().__init__()
403
+ self.alpha=nn.Parameter(torch.ones(feat_dim))
404
+ self.beta=nn.Parameter(torch.zeros(feat_dim))
405
+ self.eps=epsilon
406
+
407
+ def forward(self,x,mask):
408
+ '''
409
+ x: node feat, [batch,node_num,feat_dim]
410
+ mask: [batch,node_num]
411
+ '''
412
+ mask1 = mask.unsqueeze(2)
413
+ mask_sum = mask.sum()
414
+ mean = x.sum(dim=(0,1),keepdim=True)/(self.eps+mask_sum)
415
+ temp = (x - mean)**2
416
+ temp = temp*mask1
417
+ var = temp.sum(dim=(0,1),keepdim=True)/(self.eps+mask_sum)
418
+ rstd = torch.rsqrt(var+self.eps)
419
+ x=(x-mean)*rstd
420
+ return ((x*self.alpha) + self.beta)*mask1
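A minimal sketch of the dense GCNBlock interface used above (illustrative shapes; the mask marks which padded node slots are real and is only consulted when batch norm is enabled):

    import torch
    from models.gcn import GCNBlock

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    block = GCNBlock(input_dim=512, output_dim=64, bn=0, add_self=1, normalize_embedding=1).to(device)

    x = torch.randn(2, 30, 512, device=device)   # [batch, nodes, in_dim]
    adj = torch.rand(2, 30, 30, device=device)   # dense adjacency
    mask = torch.ones(2, 30, device=device)      # every node slot is valid
    out = block(x, adj, mask)                    # -> [2, 30, 64]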
models/layers.py ADDED
@@ -0,0 +1,280 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6
+ 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7
+ 'LayerNorm', 'AddEye']
8
+
9
+
10
+ def safe_divide(a, b):
11
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12
+ den = den + den.eq(0).type(den.type()) * 1e-9
13
+ return a / den * b.ne(0).type(b.type())
14
+
15
+
16
+ def forward_hook(self, input, output):
17
+ if type(input[0]) in (list, tuple):
18
+ self.X = []
19
+ for i in input[0]:
20
+ x = i.detach()
21
+ x.requires_grad = True
22
+ self.X.append(x)
23
+ else:
24
+ self.X = input[0].detach()
25
+ self.X.requires_grad = True
26
+
27
+ self.Y = output
28
+
29
+
30
+ def backward_hook(self, grad_input, grad_output):
31
+ self.grad_input = grad_input
32
+ self.grad_output = grad_output
33
+
34
+
35
+ class RelProp(nn.Module):
36
+ def __init__(self):
37
+ super(RelProp, self).__init__()
38
+ # if not self.training:
39
+ self.register_forward_hook(forward_hook)
40
+
41
+ def gradprop(self, Z, X, S):
42
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
43
+ return C
44
+
45
+ def relprop(self, R, alpha):
46
+ return R
47
+
48
+ class RelPropSimple(RelProp):
49
+ def relprop(self, R, alpha):
50
+ Z = self.forward(self.X)
51
+ S = safe_divide(R, Z)
52
+ C = self.gradprop(Z, self.X, S)
53
+
54
+ if torch.is_tensor(self.X) == False:
55
+ outputs = []
56
+ outputs.append(self.X[0] * C[0])
57
+ outputs.append(self.X[1] * C[1])
58
+ else:
59
+ outputs = self.X * (C[0])
60
+ return outputs
61
+
62
+ class AddEye(RelPropSimple):
63
+ # input of shape B, C, seq_len, seq_len
64
+ def forward(self, input):
65
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
66
+
67
+ class ReLU(nn.ReLU, RelProp):
68
+ pass
69
+
70
+ class GELU(nn.GELU, RelProp):
71
+ pass
72
+
73
+ class Softmax(nn.Softmax, RelProp):
74
+ pass
75
+
76
+ class LayerNorm(nn.LayerNorm, RelProp):
77
+ pass
78
+
79
+ class Dropout(nn.Dropout, RelProp):
80
+ pass
81
+
82
+
83
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
84
+ pass
85
+
86
+ class LayerNorm(nn.LayerNorm, RelProp):
87
+ pass
88
+
89
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
90
+ pass
91
+
92
+
93
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
94
+ pass
95
+
96
+
97
+ class Add(RelPropSimple):
98
+ def forward(self, inputs):
99
+ return torch.add(*inputs)
100
+
101
+ def relprop(self, R, alpha):
102
+ Z = self.forward(self.X)
103
+ S = safe_divide(R, Z)
104
+ C = self.gradprop(Z, self.X, S)
105
+
106
+ a = self.X[0] * C[0]
107
+ b = self.X[1] * C[1]
108
+
109
+ a_sum = a.sum()
110
+ b_sum = b.sum()
111
+
112
+ a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
113
+ b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
114
+
115
+ a = a * safe_divide(a_fact, a.sum())
116
+ b = b * safe_divide(b_fact, b.sum())
117
+
118
+ outputs = [a, b]
119
+
120
+ return outputs
121
+
122
+ class einsum(RelPropSimple):
123
+ def __init__(self, equation):
124
+ super().__init__()
125
+ self.equation = equation
126
+ def forward(self, *operands):
127
+ return torch.einsum(self.equation, *operands)
128
+
129
+ class IndexSelect(RelProp):
130
+ def forward(self, inputs, dim, indices):
131
+ self.__setattr__('dim', dim)
132
+ self.__setattr__('indices', indices)
133
+
134
+ return torch.index_select(inputs, dim, indices)
135
+
136
+ def relprop(self, R, alpha):
137
+ Z = self.forward(self.X, self.dim, self.indices)
138
+ S = safe_divide(R, Z)
139
+ C = self.gradprop(Z, self.X, S)
140
+
141
+ if torch.is_tensor(self.X) == False:
142
+ outputs = []
143
+ outputs.append(self.X[0] * C[0])
144
+ outputs.append(self.X[1] * C[1])
145
+ else:
146
+ outputs = self.X * (C[0])
147
+ return outputs
148
+
149
+
150
+
151
+ class Clone(RelProp):
152
+ def forward(self, input, num):
153
+ self.__setattr__('num', num)
154
+ outputs = []
155
+ for _ in range(num):
156
+ outputs.append(input)
157
+
158
+ return outputs
159
+
160
+ def relprop(self, R, alpha):
161
+ Z = []
162
+ for _ in range(self.num):
163
+ Z.append(self.X)
164
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
165
+ C = self.gradprop(Z, self.X, S)[0]
166
+
167
+ R = self.X * C
168
+
169
+ return R
170
+
171
+ class Cat(RelProp):
172
+ def forward(self, inputs, dim):
173
+ self.__setattr__('dim', dim)
174
+ return torch.cat(inputs, dim)
175
+
176
+ def relprop(self, R, alpha):
177
+ Z = self.forward(self.X, self.dim)
178
+ S = safe_divide(R, Z)
179
+ C = self.gradprop(Z, self.X, S)
180
+
181
+ outputs = []
182
+ for x, c in zip(self.X, C):
183
+ outputs.append(x * c)
184
+
185
+ return outputs
186
+
187
+
188
+ class Sequential(nn.Sequential):
189
+ def relprop(self, R, alpha):
190
+ for m in reversed(self._modules.values()):
191
+ R = m.relprop(R, alpha)
192
+ return R
193
+
194
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
195
+ def relprop(self, R, alpha):
196
+ X = self.X
197
+ beta = 1 - alpha
198
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
199
+ (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
200
+ Z = X * weight + 1e-9
201
+ S = R / Z
202
+ Ca = S * weight
203
+ R = self.X * (Ca)
204
+ return R
205
+
206
+
207
+ class Linear(nn.Linear, RelProp):
208
+ def relprop(self, R, alpha):
209
+ beta = alpha - 1
210
+ pw = torch.clamp(self.weight, min=0)
211
+ nw = torch.clamp(self.weight, max=0)
212
+ px = torch.clamp(self.X, min=0)
213
+ nx = torch.clamp(self.X, max=0)
214
+
215
+ def f(w1, w2, x1, x2):
216
+ Z1 = F.linear(x1, w1)
217
+ Z2 = F.linear(x2, w2)
218
+ S1 = safe_divide(R, Z1 + Z2)
219
+ S2 = safe_divide(R, Z1 + Z2)
220
+ C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
221
+ C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
222
+
223
+ return C1 + C2
224
+
225
+ activator_relevances = f(pw, nw, px, nx)
226
+ inhibitor_relevances = f(nw, pw, px, nx)
227
+
228
+ R = alpha * activator_relevances - beta * inhibitor_relevances
229
+
230
+ return R
231
+
232
+
233
+ class Conv2d(nn.Conv2d, RelProp):
234
+ def gradprop2(self, DY, weight):
235
+ Z = self.forward(self.X)
236
+
237
+ output_padding = self.X.size()[2] - (
238
+ (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
239
+
240
+ return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
241
+
242
+ def relprop(self, R, alpha):
243
+ if self.X.shape[1] == 3:
244
+ pw = torch.clamp(self.weight, min=0)
245
+ nw = torch.clamp(self.weight, max=0)
246
+ X = self.X
247
+ L = self.X * 0 + \
248
+ torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
249
+ keepdim=True)[0]
250
+ H = self.X * 0 + \
251
+ torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
252
+ keepdim=True)[0]
253
+ Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
254
+ torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
255
+ torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
256
+
257
+ S = R / Za
258
+ C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
259
+ R = C
260
+ else:
261
+ beta = alpha - 1
262
+ pw = torch.clamp(self.weight, min=0)
263
+ nw = torch.clamp(self.weight, max=0)
264
+ px = torch.clamp(self.X, min=0)
265
+ nx = torch.clamp(self.X, max=0)
266
+
267
+ def f(w1, w2, x1, x2):
268
+ Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
269
+ Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
270
+ S1 = safe_divide(R, Z1)
271
+ S2 = safe_divide(R, Z2)
272
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
273
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
274
+ return C1 + C2
275
+
276
+ activator_relevances = f(pw, nw, px, nx)
277
+ inhibitor_relevances = f(nw, pw, px, nx)
278
+
279
+ R = alpha * activator_relevances - beta * inhibitor_relevances
280
+ return R
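These wrappers cache their forward inputs through hooks so relevance can later be walked back layer by layer; a toy sketch of that call pattern (illustrative only, not from the commit):

    import torch
    from models.layers import Linear, ReLU, Sequential

    net = Sequential(Linear(8, 4), ReLU(), Linear(4, 2))
    x = torch.randn(1, 8)
    out = net(x)                      # forward hooks store each layer's input
    rel = net.relprop(out, alpha=1)   # relevance propagated back to the input, shape [1, 8]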
models/weight_init.py ADDED
@@ -0,0 +1,78 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding:UTF-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.init as init
7
+
8
+
9
+ def weight_init(m):
10
+ '''
11
+ Usage:
12
+ model = Model()
13
+ model.apply(weight_init)
14
+ '''
15
+ if isinstance(m, nn.Conv1d):
16
+ init.normal_(m.weight.data)
17
+ if m.bias is not None:
18
+ init.normal_(m.bias.data)
19
+ elif isinstance(m, nn.Conv2d):
20
+ init.xavier_normal_(m.weight.data)
21
+ if m.bias is not None:
22
+ init.normal_(m.bias.data)
23
+ elif isinstance(m, nn.Conv3d):
24
+ init.xavier_normal_(m.weight.data)
25
+ if m.bias is not None:
26
+ init.normal_(m.bias.data)
27
+ elif isinstance(m, nn.ConvTranspose1d):
28
+ init.normal_(m.weight.data)
29
+ if m.bias is not None:
30
+ init.normal_(m.bias.data)
31
+ elif isinstance(m, nn.ConvTranspose2d):
32
+ init.xavier_normal_(m.weight.data)
33
+ if m.bias is not None:
34
+ init.normal_(m.bias.data)
35
+ elif isinstance(m, nn.ConvTranspose3d):
36
+ init.xavier_normal_(m.weight.data)
37
+ if m.bias is not None:
38
+ init.normal_(m.bias.data)
39
+ elif isinstance(m, nn.BatchNorm1d):
40
+ init.normal_(m.weight.data, mean=1, std=0.02)
41
+ init.constant_(m.bias.data, 0)
42
+ elif isinstance(m, nn.BatchNorm2d):
43
+ init.normal_(m.weight.data, mean=1, std=0.02)
44
+ init.constant_(m.bias.data, 0)
45
+ elif isinstance(m, nn.BatchNorm3d):
46
+ init.normal_(m.weight.data, mean=1, std=0.02)
47
+ init.constant_(m.bias.data, 0)
48
+ elif isinstance(m, nn.Linear):
49
+ init.xavier_normal_(m.weight.data)
50
+ init.normal_(m.bias.data)
51
+ elif isinstance(m, nn.LSTM):
52
+ for param in m.parameters():
53
+ if len(param.shape) >= 2:
54
+ init.orthogonal_(param.data)
55
+ else:
56
+ init.normal_(param.data)
57
+ elif isinstance(m, nn.LSTMCell):
58
+ for param in m.parameters():
59
+ if len(param.shape) >= 2:
60
+ init.orthogonal_(param.data)
61
+ else:
62
+ init.normal_(param.data)
63
+ elif isinstance(m, nn.GRU):
64
+ for param in m.parameters():
65
+ if len(param.shape) >= 2:
66
+ init.orthogonal_(param.data)
67
+ else:
68
+ init.normal_(param.data)
69
+ elif isinstance(m, nn.GRUCell):
70
+ for param in m.parameters():
71
+ if len(param.shape) >= 2:
72
+ init.orthogonal_(param.data)
73
+ else:
74
+ init.normal_(param.data)
75
+
76
+
77
+ if __name__ == '__main__':
78
+ pass
option.py ADDED
@@ -0,0 +1,41 @@
1
+ ###########################################################################
2
+ # Created by: YI ZHENG
3
+ # Email: yizheng@bu.edu
4
+ # Copyright (c) 2020
5
+ ###########################################################################
6
+
7
+ import os
8
+ import argparse
9
+ import torch
10
+
11
+ class Options():
12
+ def __init__(self):
13
+ parser = argparse.ArgumentParser(description='PyTorch Classification')
14
+ parser.add_argument('--data_path', type=str, help='path to the dataset where the images are stored')
15
+ parser.add_argument('--train_set', type=str, help='path to the file listing the training sample ids')
16
+ parser.add_argument('--val_set', type=str, help='path to the file listing the validation sample ids')
17
+ parser.add_argument('--model_path', type=str, help='path to trained model')
18
+ parser.add_argument('--log_path', type=str, help='path to log files')
19
+ parser.add_argument('--task_name', type=str, help='task name for naming saved model files and log files')
20
+ parser.add_argument('--train', action='store_true', default=False, help='train only')
21
+ parser.add_argument('--test', action='store_true', default=False, help='test only')
22
+ parser.add_argument('--batch_size', type=int, default=6, help='batch size for the original global image (without downsampling)')
23
+ parser.add_argument('--log_interval_local', type=int, default=10, help='number of batches between progress log prints')
24
+ parser.add_argument('--resume', type=str, default="", help='path to a saved model checkpoint to resume from')
25
+ parser.add_argument('--graphcam', action='store_true', default=False, help='GraphCAM')
26
+ parser.add_argument('--dataset_metadata_path', type=str, help='Location of the metadata associated with the created dataset: label mapping, splits and so on')
27
+
28
+
29
+ # the parser
30
+ self.parser = parser
31
+
32
+ def parse(self):
33
+ args = self.parser.parse_args()
34
+ # default settings for epochs and lr
35
+
36
+ args.num_epochs = 120
37
+ args.lr = 1e-3
38
+
39
+ if args.test:
40
+ args.num_epochs = 1
41
+ return args
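A hypothetical way to exercise the parser (the argument values below are placeholders, not paths that exist in this commit):

    import sys
    from option import Options

    sys.argv = ['prog',
                '--data_path', 'graphs/',
                '--train_set', 'splits/train.txt', '--val_set', 'splits/val.txt',
                '--model_path', 'saved_models/', '--log_path', 'logs/',
                '--task_name', 'demo', '--train',
                '--dataset_metadata_path', 'metadata/']
    args = Options().parse()
    print(args.num_epochs, args.lr)   # 120 0.001 (forced to a single epoch when --test is given)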