from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go


def make_attention_table(att, tokens, numb, token_idx=0, layerNumb=-1):
    """Export one token's attention over a sequence to CSV and plot it as bars.

    The values ``att[layerNumb, token_idx, 1:len(tokens)]`` are paired with
    the zero-based position of each token in ``tokens[1:]``, written to
    ``amino_acid_seq_attention_{numb}.csv`` and drawn as a bar chart in which
    the three largest values are coloured crimson and the rest cornflower
    blue.  The chart is saved to ``figures/Amino_acid_seq_{numb}.png`` and
    then displayed.

    Args:
        att: Object indexed as ``att[layerNumb, token_idx, positions]``.
            Presumably attention weights laid out as (layer, query, key);
            confirm against the caller.
        tokens: Sequence of token strings.  ``tokens[0]`` is skipped
            (presumably a special start token; confirm against the caller).
        numb: Identifier interpolated into both output file names.
        token_idx: Index of the token whose attention row is plotted.
        layerNumb: Index along the first axis of ``att``; ``-1`` selects
            the last entry.

    Returns:
        None.  The function only writes files and shows the figure.
    """
    # Attention from `token_idx` towards every position except position 0.
    token_att = att[layerNumb, token_idx, range(1, len(tokens))]

    shown_tokens = tokens[1:]
    token_label = [f"<b>{token}</b>" for token in shown_tokens]
    token_numb = [f"{idx}" for idx in range(len(shown_tokens))]

    df = pd.DataFrame(
        list(zip(token_numb, token_att)),
        columns=["Amino acid", "Attention rate"],
    )
    df.to_csv(f"amino_acid_seq_attention_{numb}.csv", index=False)

    # Positions of the (up to) three highest attention values.
    top3_idx = sorted(
        range(len(token_att)),
        key=lambda i: token_att[i],
        reverse=True,
    )[:3]

    colors = ['cornflowerblue'] * len(token_numb)
    for i in top3_idx:
        colors[i] = 'crimson'

    fig = go.Figure(data=[go.Bar(
        x=df["Amino acid"],
        y=df["Attention rate"],
        marker_color=colors,
    )])

    fig.update_layout(plot_bgcolor="white")
    fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)', mirror=False)
    fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)', mirror=False)
    fig.update_layout(
        title={
            'text': "<b>Attention rate of amino acid sequence token</b>",
            'font': {'size': 40},
            'y': 0.96,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
        },
        # Tick positions are the string indices; the visible labels are
        # the bold token texts.
        xaxis=dict(
            tickmode='array',
            tickvals=token_numb,
            ticktext=token_label,
        ),
        xaxis_title={'text': "Amino acid sequence", 'font': {'size': 30}},
        yaxis_title={'text': "Attention rate", 'font': {'size': 30}},
        font=dict(family="Calibri, monospace", size=17),
    )

    image_path = f'figures/Amino_acid_seq_{numb}.png'
    # Make sure the target directory exists so the export does not fail
    # on a fresh checkout.
    Path(image_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(image_path, width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()


def read_attention(csv_path="../amino_acid_seq_attention.csv",
                   image_path='figures/Amino_acid_seq.png'):
    """Re-plot a previously exported attention CSV as a bar chart.

    The CSV must contain the columns ``"Amino acid"`` and
    ``"Attention rate"``.  The y-axis is clipped to the observed minimum and
    maximum attention rate.  The chart is saved to ``image_path`` and then
    displayed.

    Args:
        csv_path: CSV file to read.  The default keeps the original
            hard-coded location; pass a path such as
            ``amino_acid_seq_attention_{numb}.csv`` to load a file written
            by ``make_attention_table``.
        image_path: Destination of the exported PNG.  Its parent directory
            is created if it does not exist.

    Returns:
        None.  The function only writes the image and shows the figure.
    """
    df = pd.read_csv(csv_path)

    rates = df["Attention rate"]
    fig = px.bar(
        df,
        x="Amino acid",
        y="Attention rate",
        range_y=[rates.min(), rates.max()],
    )

    fig.update_layout(plot_bgcolor="white")
    fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)', mirror=False)
    fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)', mirror=False)
    fig.update_layout(
        title={
            'text': "<b>Attention rate of amino acid sequence token</b>",
            'font': {'size': 40},
            'y': 0.96,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
        },
        xaxis_title={'text': "Amino acid sequence", 'font': {'size': 30}},
        yaxis_title={'text': "Attention rate", 'font': {'size': 30}},
        font=dict(family="Calibri, monospace", size=17),
    )

    # Make sure the target directory exists so the export does not fail
    # on a fresh checkout.
    Path(image_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(str(image_path), width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()


if __name__ == '__main__':
    # Script entry point: re-plot the attention CSV with the default paths.
    read_attention()