taka-yamakoshi
committed on
Commit
•
9874228
1
Parent(s):
3cc4ad8
plot
Browse files
app.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
3 |
import streamlit as st
|
4 |
-
|
5 |
-
|
6 |
|
7 |
#import jax
|
8 |
#import jax.numpy as jnp
|
@@ -169,7 +169,6 @@ if __name__=='__main__':
|
|
169 |
load_css('style.css')
|
170 |
tokenizer,model = load_model()
|
171 |
num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
172 |
-
st.write(num_layers,num_heads)
|
173 |
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
174 |
|
175 |
main_area = st.empty()
|
@@ -260,11 +259,20 @@ if __name__=='__main__':
|
|
260 |
st.dataframe(df.style.highlight_max(axis=1))
|
261 |
|
262 |
multihead = True
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
3 |
import streamlit as st
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import seaborn as sns
|
6 |
|
7 |
#import jax
|
8 |
#import jax.numpy as jnp
|
|
|
169 |
load_css('style.css')
|
170 |
tokenizer,model = load_model()
|
171 |
num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
|
|
172 |
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
173 |
|
174 |
main_area = st.empty()
|
|
|
259 |
st.dataframe(df.style.highlight_max(axis=1))
|
260 |
|
261 |
multihead = True
|
262 |
+
effect_array = []
|
263 |
+
for token_id in range(1,len(masked_ids_option_1['sent_1'])-1):
|
264 |
+
effect_list = []
|
265 |
+
for layer_id in range(num_layers):
|
266 |
+
interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
267 |
+
if multihead:
|
268 |
+
probs = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
269 |
+
else:
|
270 |
+
probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
271 |
+
effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
|
272 |
+
effect_list.append(effect)
|
273 |
+
effect_array.append(effect_list)
|
274 |
+
effects = np.array(effect_array)
|
275 |
+
|
276 |
+
fig,ax = plt.subplots(1,1,figsize=(8,6))
|
277 |
+
ax.imshow(effects.T)
|
278 |
+
st.pyplot(fig)
|