import warnings

import altair as alt
import pandas as pd
import streamlit as st
import torch
from transformers import utils

from config import get_config, latest_weights_file_path
from model import build_transformer
from train import get_or_build_tokenizer, greedy_decode

warnings.filterwarnings("ignore")
utils.logging.set_verbosity_error()  # Suppress standard transformers warnings

st.set_page_config(page_title='Attention Visualizer', layout='wide')
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten an attention matrix into a long-format DataFrame for Altair."""
    return pd.DataFrame(
        [
            (
                r,
                c,
                float(m[r, c]),
                "%.2d - %s" % (r, row_tokens[r] if len(row_tokens) > r else "<blank>"),
                "%.2d - %s" % (c, col_tokens[c] if len(col_tokens) > c else "<blank>"),
            )
            for r in range(m.shape[0])
            for c in range(m.shape[1])
            if r < max_row and c < max_col
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )
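# Illustrative sketch (comments only, values hypothetical): with a 2x2 score
# matrix and tokens ["hello", "world"],
#
#   df = mtx2df(torch.rand(2, 2), 2, 2, ["hello", "world"], ["hello", "world"])
#
# yields one row per attention cell, e.g. (0, 1, 0.42, "00 - hello",
# "01 - world"). This long format is what Altair's mark_rect() heatmap expects.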
def get_attn_map(attn_type: str, layer: int, head: int, model):
    """Pull the cached attention scores for one layer/head out of the model."""
    if attn_type == "encoder":
        attn = model.encoder.layers[layer].self_attention_block.attention_scores
    elif attn_type == "decoder":
        attn = model.decoder.layers[layer].self_attention_block.attention_scores
    elif attn_type == "encoder-decoder":
        attn = model.decoder.layers[layer].cross_attention_block.attention_scores
    else:
        raise ValueError(f"Unknown attention type: {attn_type}")
    return attn[0, head].data
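# Assumption (based on the scratch model in model.py): attention_scores is
# cached during the forward pass with shape (batch, n_heads, seq_len, seq_len),
# so attn[0, head] selects one head's scores for the first batch element.
# A minimal sketch of reading a single cell:
#
#   scores = get_attn_map("encoder", 0, 0, model)
#   # scores[i, j] ~ how strongly source token i attends to source token j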
def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len, model):
    """Render one layer/head attention matrix as an interactive Altair heatmap."""
    df = mtx2df(
        get_attn_map(attn_type, layer, head, model),
        max_sentence_len,
        max_sentence_len,
        row_tokens,
        col_tokens,
    )
    return (
        alt.Chart(data=df)
        .mark_rect()
        .encode(
            x=alt.X("col_token", axis=alt.Axis(title="")),
            y=alt.Y("row_token", axis=alt.Axis(title="")),
            color=alt.Color("value", scale=alt.Scale(scheme="blues")),
            tooltip=["row", "column", "value", "row_token", "col_token"],
        )
        .properties(height=200, width=200, title=f"Layer {layer} Head {head}")
        .interactive()
    )
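# Design note: each heatmap is kept small (200x200) so that hconcat/vconcat in
# get_all_attention_maps below can tile many (layer, head) pairs on one page.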
def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int, model):
    """Build a grid of heatmaps: one row of charts per layer, one column per head."""
    charts = []
    for layer in layers:
        row_charts = []
        for head in heads:
            row_charts.append(attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len, model))
        charts.append(alt.hconcat(*row_charts))
    return alt.vconcat(*charts)
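# Minimal usage sketch (assumes a loaded model and a token list `tokens`):
#
#   grid = get_all_attention_maps(
#       "encoder", layers=[0, 1], heads=[0, 1, 2],
#       row_tokens=tokens, col_tokens=tokens,
#       max_sentence_len=len(tokens), model=model,
#   )
#   st.write(grid)  # Streamlit renders Altair charts natively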
def initiate_model(config, device):
    """Build the transformer and load the latest checkpoint weights."""
    tokenizer_src = get_or_build_tokenizer(config, None, config["lang_src"])
    tokenizer_tgt = get_or_build_tokenizer(config, None, config["lang_tgt"])
    model = build_transformer(
        tokenizer_src.get_vocab_size(),
        tokenizer_tgt.get_vocab_size(),
        config["seq_len"],
        config["seq_len"],
        d_model=config["d_model"],
    ).to(device)
    model_filename = latest_weights_file_path(config)
    # Load on CPU so the app also runs on hosts without a GPU
    state = torch.load(model_filename, map_location=torch.device('cpu'))
    model.load_state_dict(state['model_state_dict'])
    return model, tokenizer_src, tokenizer_tgt
def process_input(input_text, tokenizer_src, tokenizer_tgt, model, config, device):
    """Tokenize the input, run greedy decoding, and return token strings plus the translation."""
    src = tokenizer_src.encode(input_text)
    # Wrap the source ids with [SOS]/[EOS] and pad up to seq_len
    src = torch.cat([
        torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64),
        torch.tensor(src.ids, dtype=torch.int64),
        torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
        torch.tensor([tokenizer_src.token_to_id('[PAD]')] * (config['seq_len'] - len(src.ids) - 2), dtype=torch.int64),
    ], dim=0).to(device)
    # Mask out padding positions; shape (1, 1, seq_len) broadcasts over heads
    source_mask = (src != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)
    encoder_input_tokens = [tokenizer_src.id_to_token(i) for i in src.cpu().numpy()]
    encoder_input_tokens = [i for i in encoder_input_tokens if i != '[PAD]']
    model_out = greedy_decode(model, src, source_mask, tokenizer_src, tokenizer_tgt, config['seq_len'], device)
    decoder_input_tokens = [tokenizer_tgt.id_to_token(i) for i in model_out.cpu().numpy()]
    output = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
    return encoder_input_tokens, decoder_input_tokens, output
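# Caveat: inputs longer than seq_len - 2 are not truncated above; the [PAD]
# multiplier goes negative and silently yields an empty list, so the encoder
# input would exceed seq_len. A defensive clamp (hypothetical, not wired in):
#
#   ids = src.ids[: config['seq_len'] - 2]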
# Alternative visualization path via BERTViz (unused; kept for reference).
# If re-enabled, it additionally needs:
#   from transformers import AutoModel, AutoTokenizer
#   from bertviz import model_view
# def get_html_data(model_name, input_text):
#     model_name = "microsoft/xtremedistil-l12-h384-uncased"
#     model = AutoModel.from_pretrained(model_name, output_attentions=True, cache_dir='__pycache__')  # Configure model to return attention values
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     inputs = tokenizer.encode(input_text, return_tensors='pt')  # Tokenize input text
#     outputs = model(inputs)  # Run model
#     attention = outputs[-1]  # Retrieve attention from model outputs
#     tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings
#     model_html = model_view(attention, tokens, html_action="return")  # Render the model view as HTML
#     with open("static/model_view.html", 'w') as file:
#         file.write(model_html.data)
def main():
    st.title('Transformer Visualizer')
    st.write('This app visualizes the attention of a transformer model on a given sentence.')

    config = get_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, tokenizer_src, tokenizer_tgt = initiate_model(config, device)

    # Sidebar: prompt plus attention-type, layer, and head selectors
    with st.sidebar:
        input_text = st.text_input('Enter a sentence')
        attn_type = st.selectbox('Select attention type', ['encoder', 'decoder', 'encoder-decoder'])
        # The trained model has 6 layers with 8 heads each
        layers = st.multiselect('Select layers', list(range(6)))
        heads = st.multiselect('Select heads', list(range(8)))
        # Shortcuts to select all layers/heads at once
        if st.checkbox('Select all layers'):
            layers = list(range(6))
        if st.checkbox('Select all heads'):
            heads = list(range(8))

    if input_text != '':
        with st.spinner("Translating..."):
            encoder_input_tokens, decoder_input_tokens, output = process_input(input_text, tokenizer_src, tokenizer_tgt, model, config, device)
            max_sentence_len = len(encoder_input_tokens)
            row_tokens = encoder_input_tokens
            col_tokens = decoder_input_tokens
            st.write('Input:', ' '.join(encoder_input_tokens))
            st.write('Output:', ' '.join(decoder_input_tokens))
            st.write('Translated:', output)
        st.write('Attention Visualization')
        with st.spinner("Visualizing Attention..."):
            # Encoder self-attention relates source tokens to source tokens,
            # decoder self-attention relates target tokens to target tokens,
            # and cross-attention relates the two sequences.
            if attn_type == 'encoder':
                st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, row_tokens, max_sentence_len, model))
            elif attn_type == 'decoder':
                st.write(get_all_attention_maps(attn_type, layers, heads, col_tokens, col_tokens, max_sentence_len, model))
            elif attn_type == 'encoder-decoder':
                st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, col_tokens, max_sentence_len, model))
    else:
        st.write('Enter a sentence to visualize the attention of the model')

    # Footer with the GitHub repo link and dataset link
    st.markdown('---')
    st.write('Made by [Pratik Dwivedi](https://github.com/Dekode1859)')
    st.write('Check out the scratch implementation and visualizer code on [GitHub](https://github.com/Dekode1859/transformer-visualizer)')
    st.write('Dataset: [Opus Books: English-Italian](https://huggingface.co/datasets/Helsinki-NLP/opus_books)')
if __name__ == '__main__':
    main()