|
import networkx as nx |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
import seaborn as sns |
|
import itertools |
|
import matplotlib as mpl |
|
|
|
|
|
# Plot styling applied globally when this module is imported: font sizes for
# text, axis labels, legend, titles and tick labels, plus a thin axes frame.
rc = {
    'font.size': 10,
    'axes.labelsize': 10,
    'legend.fontsize': 10.0,
    'axes.titlesize': 32,
    'xtick.labelsize': 20,
    'ytick.labelsize': 16,
}
plt.rcParams.update(rc)
mpl.rcParams['axes.linewidth'] = .5
|
|
|
|
|
def plot_attention_heatmap(att, s_position, t_positions, input_tokens):
    """Plot a seaborn heatmap of ``att[:, s_position, t_positions]``.

    The selected slice is reversed along its first axis before plotting, so
    the last index of axis 0 is drawn in the top row.

    Args:
        att: Array indexed as ``att[:, s_position, t_positions]``
            (presumably layers x query x key attention -- confirm with caller).
        s_position: Index used on the second axis of ``att``.
        t_positions: Indices used on the third axis of ``att``; also decides
            which entries of ``input_tokens`` become column labels.
        input_tokens: Sequence of token strings.

    Returns:
        The axes object returned by ``sns.heatmap``.
    """
    selected = att[:, s_position, t_positions]
    cls_att = np.flip(selected, axis=0)

    # Column labels: tokens whose position is listed in t_positions, taken in
    # the order they occur in input_tokens.
    column_labels = [
        token for position, token in enumerate(input_tokens)
        if position in t_positions
    ]

    # Row labels count down from att.shape[0] to 1; odd numbers are left blank.
    row_labels = []
    for number in range(att.shape[0], 0, -1):
        row_labels.append(str(number) if number % 2 == 0 else '')

    return sns.heatmap(
        cls_att,
        xticklabels=column_labels,
        yticklabels=row_labels,
        cmap="YlOrRd",
    )
|
|
|
def convert_adjmat_tomats(adjmat, n_layers, l):
    """Cut ``n_layers`` square blocks of side ``l`` out of ``adjmat``.

    Block ``i`` is read from rows ``(i+1)*l:(i+2)*l`` and columns
    ``i*l:(i+1)*l``.

    Args:
        adjmat: 2-D array supporting slice indexing.
        n_layers: Number of blocks to extract.
        l: Side length of each block.

    Returns:
        A float array of shape ``(n_layers, l, l)``.
    """
    mats = np.zeros((n_layers, l, l))

    for layer in range(n_layers):
        col_start = layer * l
        row_start = col_start + l  # rows sit one block below the columns
        mats[layer] = adjmat[row_start:row_start + l, col_start:row_start]

    return mats
|
|
|
def make_residual_attention(attentions):
    """Stack attention tensors and build a row-normalised variant with identity added.

    Args:
        attentions: Iterable of objects exposing ``.detach().cpu().numpy()``
            (e.g. torch tensors). Presumably one tensor per layer shaped
            (batch, heads, length, length) -- confirm with caller.

    Returns:
        A tuple ``(attentions_mat, res_att_mat)``:
        ``attentions_mat`` is the stacked array with index 0 taken along its
        second axis; ``res_att_mat`` is ``attentions_mat`` averaged over its
        axis 1, plus an identity matrix, with each row of the last axis
        rescaled to sum to 1.
    """
    per_item = [tensor.detach().cpu().numpy() for tensor in attentions]
    attentions_mat = np.asarray(per_item)[:, 0]

    # Average over axis 1 (sum divided by the size of that axis).
    axis1_size = attentions_mat.shape[1]
    averaged = attentions_mat.sum(axis=1) / axis1_size

    # Add the identity to every matrix, then renormalise each row.
    identity = np.eye(averaged.shape[1])
    with_identity = averaged + identity[None, ...]
    row_totals = with_identity.sum(axis=-1)
    res_att_mat = with_identity / row_totals[..., None]

    return attentions_mat, res_att_mat
|
|
|
|
|
|
|
|
|
|
|
def make_flow_network(mat, input_tokens):
    """Build a layered directed graph with edge capacities from stacked matrices.

    Nodes are numbered ``layer * length + position``; layer 0 holds the input
    tokens. Entry ``mat[i-1][k_f][k_t]`` becomes the edge from node
    ``i*length + k_f`` to node ``(i-1)*length + k_t``.

    Args:
        mat: Array of shape ``(n_layers, length, length)``.
        input_tokens: Sequence with at least ``length`` tokens, used to label
            the layer-0 nodes.

    Returns:
        A tuple ``(net_graph, labels_to_index)``. ``net_graph`` is an
        ``nx.DiGraph`` whose edges (the non-zero adjacency entries) carry both
        ``weight`` and ``capacity`` attributes. ``labels_to_index`` maps
        ``"<k>_<token>"`` (layer 0) and ``"L<i>_<k>"`` (layer i >= 1) to node
        indices.
    """
    n_layers, length, _ = mat.shape
    n_nodes = (n_layers + 1) * length
    adj_mat = np.zeros((n_nodes, n_nodes))

    labels_to_index = {}
    for k in range(length):
        labels_to_index["{}_{}".format(k, input_tokens[k])] = k

    for layer in range(1, n_layers + 1):
        row_start = layer * length
        col_start = row_start - length
        for k in range(length):
            labels_to_index["L{}_{}".format(layer, k)] = row_start + k
        # One block copy per layer instead of an element-by-element loop.
        adj_mat[row_start:row_start + length,
                col_start:col_start + length] = mat[layer - 1][:, :length]

    # from_numpy_matrix was removed in networkx 3.0; from_numpy_array is the
    # equivalent available since networkx 2.0.
    net_graph = nx.from_numpy_array(adj_mat, create_using=nx.DiGraph())

    # Only non-zero entries became edges, so visiting the edges is enough to
    # give each one a capacity equal to its adjacency value.
    for u, v, data in net_graph.edges(data=True):
        data['capacity'] = adj_mat[u, v]

    return net_graph, labels_to_index
|
|
|
|
|
def make_input_node(attention_mat, res_labels_to_index):
    """Return the labels whose node index is below ``attention_mat.shape[-1]``.

    Args:
        attention_mat: Array; only the size of its last axis is used.
        res_labels_to_index: Mapping from node label to integer node index.

    Returns:
        List of matching labels, in the mapping's iteration order.
    """
    limit = attention_mat.shape[-1]
    return [
        label
        for label, index in res_labels_to_index.items()
        if index < limit
    ]
|
|
|
|
|
|
|
|
|
|
|
def get_adjmat(mat, input_tokens):
    """Lay stacked per-layer matrices out as one block adjacency matrix.

    Nodes are numbered ``layer * length + position``; layer 0 holds the input
    tokens. Entry ``mat[i-1][k_f][k_t]`` is written at row ``i*length + k_f``,
    column ``(i-1)*length + k_t``; every other entry stays zero.

    Args:
        mat: Array of shape ``(n_layers, length, length)``.
        input_tokens: Sequence with at least ``length`` tokens, used to label
            the layer-0 nodes. Tokens are formatted with ``str.format``, so
            non-string tokens are accepted.

    Returns:
        A tuple ``(adj_mat, labels_to_index)``. ``adj_mat`` is a float array of
        shape ``((n_layers+1)*length, (n_layers+1)*length)``.
        ``labels_to_index`` maps ``"<k>_<token>"`` (layer 0) and ``"L<i>_<k>"``
        (layer i >= 1) to integer node indices.
    """
    n_layers, length, _ = mat.shape
    n_nodes = (n_layers + 1) * length
    adj_mat = np.zeros((n_nodes, n_nodes))

    labels_to_index = {}
    for k in range(length):
        labels_to_index["{}_{}".format(k, input_tokens[k])] = k

    for layer in range(1, n_layers + 1):
        row_start = layer * length
        col_start = row_start - length
        for k in range(length):
            labels_to_index["L{}_{}".format(layer, k)] = row_start + k
        # One block copy per layer instead of an element-by-element loop.
        adj_mat[row_start:row_start + length,
                col_start:col_start + length] = mat[layer - 1][:, :length]

    return adj_mat, labels_to_index
|
|
|
def draw_attention_graph(adjmat, labels_to_index, n_layers, length):
    """Draw the layered graph described by ``adjmat`` and return it.

    Nodes are placed on a grid: node ``i*length + k`` sits at
    ``x = (i + 0.4) * 2`` and ``y = length - k``. Only nodes with index below
    ``length`` get a text label, drawn slightly to the left at ``x = i * 2``.
    Each edge is drawn with a line width equal to its weight.

    Args:
        adjmat: Square adjacency matrix; non-zero entries become edges.
        labels_to_index: Mapping from node label to node index. For nodes with
            index below ``length`` the text after the first underscore of the
            label is displayed.
        n_layers: Number of layers above the input layer.
        length: Number of nodes per layer.

    Returns:
        The ``nx.DiGraph`` built from ``adjmat``; its edges carry both
        ``weight`` and ``capacity`` attributes.
    """
    A = np.asarray(adjmat)

    # from_numpy_matrix was removed in networkx 3.0; from_numpy_array is the
    # equivalent available since networkx 2.0.
    net_graph = nx.from_numpy_array(A, create_using=nx.DiGraph())

    # Only non-zero entries became edges, so visiting the edges is enough to
    # give each one a capacity equal to its adjacency value.
    for u, v, data in net_graph.edges(data=True):
        data['capacity'] = A[u, v]

    pos = {}
    label_pos = {}
    for layer in range(n_layers + 1):
        for k in range(length):
            node = layer * length + k
            pos[node] = ((layer + 0.4) * 2, length - k)
            label_pos[node] = (layer * 2, length - k)

    index_to_labels = {}
    for key, index in labels_to_index.items():
        if index >= length:
            index_to_labels[index] = ''
        else:
            # Split on the first underscore only, so tokens that themselves
            # contain underscores are shown in full.
            index_to_labels[index] = key.split("_", 1)[-1]

    # draw_networkx_nodes has no `labels` parameter (passing one raises
    # TypeError on current networkx); the text is drawn by the next call.
    nx.draw_networkx_nodes(net_graph, pos, node_color='green', node_size=50)
    nx.draw_networkx_labels(net_graph, pos=label_pos, labels=index_to_labels, font_size=18)

    # Group edges by weight in one pass, then draw each group with a line
    # width equal to that weight.
    edges_by_weight = {}
    for node1, node2, data in net_graph.edges(data=True):
        edges_by_weight.setdefault(data['weight'], []).append((node1, node2))

    for weight, weighted_edges in edges_by_weight.items():
        nx.draw_networkx_edges(net_graph, pos, edgelist=weighted_edges,
                               width=weight, edge_color='darkblue')

    return net_graph
|
|
|
def compute_flows(G, labels_to_index, input_nodes, length):
    """Compute normalised maximum-flow values from every non-input node to each input node.

    For each label in ``labels_to_index`` that is not in ``input_nodes``, the
    maximum flow (Edmonds-Karp) from that node ``u`` to every input node ``v``
    is stored at ``flow_values[u][(layer(u) - 1) * length + v]``, where
    ``layer(u) = u // length``. Each filled row is then divided by its sum.

    Args:
        G: networkx graph whose edges carry a ``capacity`` attribute.
        labels_to_index: Mapping from node label to integer node index.
        input_nodes: Labels (keys of ``labels_to_index``) of the input nodes.
        length: Number of nodes per layer.

    Returns:
        A ``(N, N)`` float array with ``N = len(labels_to_index)``. Rows of
        input nodes, and rows whose total flow is zero, are left as zeros.
    """
    number_of_nodes = len(labels_to_index)
    flow_values = np.zeros((number_of_nodes, number_of_nodes))
    input_keys = set(input_nodes)  # constant-time membership tests

    for key in tqdm(labels_to_index, desc="flow algorithms", total=number_of_nodes):
        if key in input_keys:
            continue

        u = labels_to_index[key]
        pre_layer = u // length - 1
        for inp_node_key in input_nodes:
            v = labels_to_index[inp_node_key]
            flow_values[u][pre_layer * length + v] = nx.maximum_flow_value(
                G, u, v, flow_func=nx.algorithms.flow.edmonds_karp)

        # Skip normalisation when nothing flows: 0/0 would fill the row
        # with NaN.
        total = flow_values[u].sum()
        if total != 0:
            flow_values[u] /= total

    return flow_values
|
|
|
def compute_node_flow(G, labels_to_index, input_nodes, output_nodes, length):
    """Compute normalised maximum-flow values from selected nodes to each input node.

    For each label in ``output_nodes`` that is not in ``input_nodes``, the
    maximum flow (Edmonds-Karp) from that node ``u`` to every input node ``v``
    is stored at ``flow_values[u][(layer(u) - 1) * length + v]``, where
    ``layer(u) = u // length``. Each filled row is then divided by its sum.

    Args:
        G: networkx graph whose edges carry a ``capacity`` attribute.
        labels_to_index: Mapping from node label to integer node index.
        input_nodes: Labels (keys of ``labels_to_index``) of the input nodes.
        output_nodes: Labels (keys of ``labels_to_index``) of the nodes to
            compute flows for.
        length: Number of nodes per layer.

    Returns:
        A ``(N, N)`` float array with ``N = len(labels_to_index)``. Rows that
        were not computed, and rows whose total flow is zero, are left as
        zeros.
    """
    number_of_nodes = len(labels_to_index)
    flow_values = np.zeros((number_of_nodes, number_of_nodes))
    input_keys = set(input_nodes)  # constant-time membership tests

    for key in output_nodes:
        if key in input_keys:
            continue

        u = labels_to_index[key]
        pre_layer = u // length - 1
        for inp_node_key in input_nodes:
            v = labels_to_index[inp_node_key]
            flow_values[u][pre_layer * length + v] = nx.maximum_flow_value(
                G, u, v, flow_func=nx.algorithms.flow.edmonds_karp)

        # Skip normalisation when nothing flows: 0/0 would fill the row
        # with NaN.
        total = flow_values[u].sum()
        if total != 0:
            flow_values[u] /= total

    return flow_values
|
|
|
def compute_joint_attention(att_mat, add_residual=True):
    """Accumulate per-layer matrices by successive matrix products.

    Args:
        att_mat: Array of stacked square matrices, indexed by layer on axis 0
            (presumably shape (layers, length, length) -- confirm with caller).
        add_residual: When true, an identity matrix is added to every layer
            and each row is rescaled to sum to 1 before accumulating.

    Returns:
        Array with the same shape as ``att_mat`` where entry 0 is the (possibly
        augmented) first matrix and entry ``i`` is the augmented matrix ``i``
        multiplied by entry ``i - 1``.
    """
    aug_att_mat = att_mat
    if add_residual:
        identity = np.eye(att_mat.shape[1])
        summed = att_mat + identity[None, ...]
        aug_att_mat = summed / summed.sum(axis=-1)[..., None]

    joint_attentions = np.zeros(aug_att_mat.shape)
    joint_attentions[0] = aug_att_mat[0]

    n_layers = joint_attentions.shape[0]
    for layer in range(1, n_layers):
        previous = joint_attentions[layer - 1]
        joint_attentions[layer] = aug_att_mat[layer].dot(previous)

    return joint_attentions