taquynhnga commited on
Commit
8a287fa
·
1 Parent(s): 81f0a60

remove st.tabs

Browse files
.gitattributes CHANGED
@@ -1,4 +1,4 @@
1
  *.pkl filter=lfs diff=lfs merge=lfs -text
2
- *.json filter=lfs diff=lfs merge=lfs -text
3
  .csv filter=lfs diff=lfs merge=lfs -text
4
  data/** filter=lfs diff=lfs merge=lfs -text
 
 
1
  *.pkl filter=lfs diff=lfs merge=lfs -text
 
2
  .csv filter=lfs diff=lfs merge=lfs -text
3
  data/** filter=lfs diff=lfs merge=lfs -text
4
+ Visual-Explanation-Methods-PyTorch/** filter=lfs diff=lfs merge=lfs -text
pages/1_Maximally_activating_patches.py CHANGED
@@ -64,98 +64,96 @@ props = {
64
  }
65
  }
66
 
67
- convnext_tab, resnet_tab, mobilenet_tab = st.tabs(['ConvNeXt', 'ResNet', 'MobileNet'])
68
 
69
- with convnext_tab:
70
- col1, col2 = st.columns((2,5))
71
- col1.markdown("#### Architecture")
72
- col1.write('')
73
- col1.write('Click on a layer below to generate top-k maximally activating image patches')
74
- col1.graphviz_chart(convnext_graph)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  with col2:
77
- st.markdown("#### Output")
78
- nodes = on_click_graph(key='toggle_buttons', **props)
79
-
80
- # -------------------------- DISPLAY OUTPUT -----------------------------------
81
-
82
- if nodes != None:
83
- clicked_node_title = nodes["choice"]["node_title"]
84
- clicked_node_id = nodes["choice"]["node_id"]
85
- display_text, activation_key = chosen_node_text(clicked_node_title)
86
- col2.write(f'**Chosen layer:** {display_text}')
87
- # col2.write(f'**Activation key:** {activation_key}')
88
-
89
- hightlight_syle = f'''
90
- <style>
91
- div[data-stale]:has(iframe) {{
92
- height: 0;
93
- }}
94
- #{clicked_node_id}>polygon {{
95
- fill: {HIGHTLIGHT_COLOR};
96
- stroke: {HIGHTLIGHT_COLOR};
97
- }}
98
- </style>
99
- '''
100
- col2.markdown(hightlight_syle, unsafe_allow_html=True)
101
-
102
- with col2:
103
- layer_infos = None
104
- with st.form('top_k_form'):
105
- activation_path = './data/activation/convnext_activation.json'
106
- activation = load_activation(activation_path)
107
- num_channels = activation[activation_key].shape[1]
108
-
109
- top_k = st.slider('Choose K for top-K maximally activating patches', 1,20, value=10)
110
- channel_start, channel_end = st.slider(
111
- 'Choose channel range of this layer (recommend to choose small range less than 30)',
112
- 1, num_channels, value=(1, 30))
113
- summit_button = st.form_submit_button('Generate image patches')
114
- if summit_button:
 
 
 
 
 
 
115
 
116
- activation = activation[activation_key][:top_k,:,:]
117
- layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')
118
- # st.write(channel_start, channel_end)
119
- # st.write(activation.shape, activation.shape[1])
120
-
121
- if layer_infos != None:
122
- num_cols, num_rows = top_k, channel_end - channel_start + 1
123
- # num_rows = activation.shape[1]
124
- top_k_coor_max_ = activation
125
- st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")
126
-
127
- for row in range(channel_start, channel_end+1):
128
- if row == channel_start:
129
- top_margin = 50
130
- fig = make_subplots(
131
- rows=1, cols=num_cols,
132
- subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True)
133
- else:
134
- top_margin = 0
135
- fig = make_subplots(rows=1, cols=num_cols)
136
- for col in range(1, num_cols+1):
137
- k, c = col-1, row-1
138
- img_index = int(top_k_coor_max_[k, c, 3])
139
- activation_value = top_k_coor_max_[k, c, 0]
140
- img = dataset_dict[img_index//10_000][img_index%10_000]['image']
141
- class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
142
- class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
143
-
144
- idx_x, idx_y = top_k_coor_max_[k, c, 1], top_k_coor_max_[k, c, 2]
145
- x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
146
- img = np.array(img)[y1:y2, x1:x2, :]
147
-
148
- hovertemplate = f"""Top-{col}<br>Activation value: {activation_value:.5f}<br>Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"""
149
- fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col)
150
- fig.update_xaxes(showticklabels=False, showgrid=False)
151
- fig.update_yaxes(showticklabels=False, showgrid=False)
152
- fig.update_layout(margin={'b':0, 't':top_margin, 'r':0, 'l':0})
153
- fig.update_layout(showlegend=False, yaxis_title=row)
154
- fig.update_layout(height=100, plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)')
155
- fig.update_layout(hoverlabel=dict(bgcolor="#e9f2f7"))
156
- st.plotly_chart(fig, use_container_width=True)
157
-
158
-
159
- else:
160
- col2.markdown(f'Chosen layer: <code>None</code>', unsafe_allow_html=True)
161
- col2.markdown("""<style>div[data-stale]:has(iframe) {height: 0};""", unsafe_allow_html=True)
 
64
  }
65
  }
66
 
 
67
 
68
+ col1, col2 = st.columns((2,5))
69
+ col1.markdown("#### Architecture")
70
+ col1.write('')
71
+ col1.write('Click on a layer below to generate top-k maximally activating image patches')
72
+ col1.graphviz_chart(convnext_graph)
73
+
74
+ with col2:
75
+ st.markdown("#### Output")
76
+ nodes = on_click_graph(key='toggle_buttons', **props)
77
+
78
+ # -------------------------- DISPLAY OUTPUT -----------------------------------
79
+
80
+ if nodes != None:
81
+ clicked_node_title = nodes["choice"]["node_title"]
82
+ clicked_node_id = nodes["choice"]["node_id"]
83
+ display_text, activation_key = chosen_node_text(clicked_node_title)
84
+ col2.write(f'**Chosen layer:** {display_text}')
85
+ # col2.write(f'**Activation key:** {activation_key}')
86
+
87
+ hightlight_syle = f'''
88
+ <style>
89
+ div[data-stale]:has(iframe) {{
90
+ height: 0;
91
+ }}
92
+ #{clicked_node_id}>polygon {{
93
+ fill: {HIGHTLIGHT_COLOR};
94
+ stroke: {HIGHTLIGHT_COLOR};
95
+ }}
96
+ </style>
97
+ '''
98
+ col2.markdown(hightlight_syle, unsafe_allow_html=True)
99
 
100
  with col2:
101
+ layer_infos = None
102
+ with st.form('top_k_form'):
103
+ activation_path = './data/activation/convnext_activation.json'
104
+ activation = load_activation(activation_path)
105
+ num_channels = activation[activation_key].shape[1]
106
+
107
+ top_k = st.slider('Choose K for top-K maximally activating patches', 1,20, value=10)
108
+ channel_start, channel_end = st.slider(
109
+ 'Choose channel range of this layer (recommend to choose small range less than 30)',
110
+ 1, num_channels, value=(1, 30))
111
+ summit_button = st.form_submit_button('Generate image patches')
112
+ if summit_button:
113
+
114
+ activation = activation[activation_key][:top_k,:,:]
115
+ layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')
116
+ # st.write(channel_start, channel_end)
117
+ # st.write(activation.shape, activation.shape[1])
118
+
119
+ if layer_infos != None:
120
+ num_cols, num_rows = top_k, channel_end - channel_start + 1
121
+ # num_rows = activation.shape[1]
122
+ top_k_coor_max_ = activation
123
+ st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")
124
+
125
+ for row in range(channel_start, channel_end+1):
126
+ if row == channel_start:
127
+ top_margin = 50
128
+ fig = make_subplots(
129
+ rows=1, cols=num_cols,
130
+ subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True)
131
+ else:
132
+ top_margin = 0
133
+ fig = make_subplots(rows=1, cols=num_cols)
134
+ for col in range(1, num_cols+1):
135
+ k, c = col-1, row-1
136
+ img_index = int(top_k_coor_max_[k, c, 3])
137
+ activation_value = top_k_coor_max_[k, c, 0]
138
+ img = dataset_dict[img_index//10_000][img_index%10_000]['image']
139
+ class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
140
+ class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
141
+
142
+ idx_x, idx_y = top_k_coor_max_[k, c, 1], top_k_coor_max_[k, c, 2]
143
+ x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
144
+ img = np.array(img)[y1:y2, x1:x2, :]
145
 
146
+ hovertemplate = f"""Top-{col}<br>Activation value: {activation_value:.5f}<br>Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"""
147
+ fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col)
148
+ fig.update_xaxes(showticklabels=False, showgrid=False)
149
+ fig.update_yaxes(showticklabels=False, showgrid=False)
150
+ fig.update_layout(margin={'b':0, 't':top_margin, 'r':0, 'l':0})
151
+ fig.update_layout(showlegend=False, yaxis_title=row)
152
+ fig.update_layout(height=100, plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)')
153
+ fig.update_layout(hoverlabel=dict(bgcolor="#e9f2f7"))
154
+ st.plotly_chart(fig, use_container_width=True)
155
+
156
+
157
+ else:
158
+ col2.markdown(f'Chosen layer: <code>None</code>', unsafe_allow_html=True)
159
+ col2.markdown("""<style>div[data-stale]:has(iframe) {height: 0};""", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -10,7 +10,8 @@ Pillow==9.3.0
10
  plotly==5.11.0
11
  scipy==1.9.3
12
  setuptools==65.5.0
13
- streamlit==1.15.2
 
14
  torch==1.10.1
15
  torchvision==0.11.2
16
  tqdm==4.64.1
 
10
  plotly==5.11.0
11
  scipy==1.9.3
12
  setuptools==65.5.0
13
+ # streamlit==1.15.2
14
+ streamlit==1.10.0
15
  torch==1.10.1
16
  torchvision==0.11.2
17
  tqdm==4.64.1