makenmtviz / bertviz_gradio.py
Gabriela Nicole Gonzalez Saez
Add files
056bbdc
raw
history blame
No virus
11 kB
import json
import os
import uuid
from IPython.core.display import display, HTML, Javascript
from bertviz.util import format_special_chars, format_attention, num_layers
# Script tag that loads require.js; needed for Colab or JupyterLab.
_REQUIRE_JS_TAG = '<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>'

# Values accepted for the ``html_action`` argument of ``head_view_mod``.
_HTML_ACTIONS = ('view', 'return', 'gradio')


def _attn_entry(name, attn, left_text, right_text):
    """Build one entry of the attention data list consumed by head_view.js."""
    return {
        'name': name,
        'attn': attn.tolist(),
        'left_text': left_text,
        'right_text': right_text
    }


def _self_attention_data(attention, tokens, sentence_b_start):
    """Build the attention entries for a self-attention model.

    ``attention`` is the value returned by ``format_attention``. Without
    ``sentence_b_start`` a single unnamed entry is produced; otherwise an 'All'
    entry plus the four sentence A/B combinations are produced.
    """
    if sentence_b_start is None:
        return [_attn_entry(None, attention, tokens, tokens)]

    slice_a = slice(0, sentence_b_start)  # Positions corresponding to sentence A in input
    slice_b = slice(sentence_b_start, len(tokens))  # Positions corresponding to sentence B in input
    views = (
        ('Sentence A -> Sentence A', slice_a, slice_a),
        ('Sentence B -> Sentence B', slice_b, slice_b),
        ('Sentence A -> Sentence B', slice_a, slice_b),
        ('Sentence B -> Sentence A', slice_b, slice_a),
    )
    entries = [_attn_entry('All', attention, tokens, tokens)]
    entries.extend(
        _attn_entry(name, attention[:, :, left, right], tokens[left], tokens[right])
        for name, left, right in views
    )
    return entries


def _encoder_decoder_data(encoder_attention, decoder_attention, cross_attention,
                          encoder_tokens, decoder_tokens, include_layers):
    """Build the attention entries for an encoder-decoder model.

    Returns ``(entries, include_layers)``. When ``include_layers`` is None it is
    derived from the first attention argument that is present and then reused
    for the remaining ones.
    """
    entries = []
    if encoder_attention is not None:
        if encoder_tokens is None:
            raise ValueError("'encoder_tokens' required if 'encoder_attention' is not None")
        if include_layers is None:
            include_layers = list(range(num_layers(encoder_attention)))
        encoder_attention = format_attention(encoder_attention, include_layers)
        entries.append(_attn_entry('Encoder', encoder_attention, encoder_tokens, encoder_tokens))
    if decoder_attention is not None:
        if decoder_tokens is None:
            raise ValueError("'decoder_tokens' required if 'decoder_attention' is not None")
        if include_layers is None:
            include_layers = list(range(num_layers(decoder_attention)))
        decoder_attention = format_attention(decoder_attention, include_layers)
        entries.append(_attn_entry('Decoder', decoder_attention, decoder_tokens, decoder_tokens))
    if cross_attention is not None:
        if encoder_tokens is None:
            raise ValueError("'encoder_tokens' required if 'cross_attention' is not None")
        if decoder_tokens is None:
            raise ValueError("'decoder_tokens' required if 'cross_attention' is not None")
        if include_layers is None:
            include_layers = list(range(num_layers(cross_attention)))
        cross_attention = format_attention(cross_attention, include_layers)
        # Cross attention goes from decoder positions (left) to encoder positions (right).
        entries.append(_attn_entry('Cross', cross_attention, decoder_tokens, encoder_tokens))
    return entries, include_layers


def _compose_vis_html(vis_id, attn_data):
    """Return the HTML skeleton (root div, layer selector, optional filter selector)."""
    if len(attn_data) > 1:
        options = '\n'.join(
            f'<option value="{i}">{entry["name"]}</option>'
            for i, entry in enumerate(attn_data)
        )
        select_html = f'Attention: <select id="filter">{options}</select>'
    else:
        select_html = ""
    return '\n'.join((
        '',
        f"<div id=\"{vis_id}\" style=\"font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;\">",
        '<span style="user-select:none">',
        'Layer: <select id="layer"></select>',
        select_html,
        '</span>',
        "<div id='vis'></div>",
        '</div>',
        '',
    ))


def _check_and_prettify(attn_data, prettify_tokens):
    """Check that attention dimensions match the token lists, then prettify tokens in place."""
    for entry in attn_data:
        attn_seq_len_left = len(entry['attn'][0][0])
        if attn_seq_len_left != len(entry['left_text']):
            raise ValueError(
                f"Attention has {attn_seq_len_left} positions, while number of tokens is {len(entry['left_text'])} "
                f"for tokens: {' '.join(entry['left_text'])}"
            )
        attn_seq_len_right = len(entry['attn'][0][0][0])
        if attn_seq_len_right != len(entry['right_text']):
            raise ValueError(
                f"Attention has {attn_seq_len_right} positions, while number of tokens is {len(entry['right_text'])} "
                f"for tokens: {' '.join(entry['right_text'])}"
            )
        if prettify_tokens:
            entry['left_text'] = format_special_chars(entry['left_text'])
            entry['right_text'] = format_special_chars(entry['right_text'])


def _load_head_view_js(params):
    """Read head_view.js from this module's directory and substitute PYTHON_PARAMS."""
    location = os.path.realpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__)))
    # Context manager so the file handle is closed even if reading fails.
    with open(os.path.join(location, 'head_view.js'), encoding='utf-8') as js_file:
        template = js_file.read()
    return template.replace("PYTHON_PARAMS", json.dumps(params))


def head_view_mod(
        attention=None,
        tokens=None,
        sentence_b_start=None,
        prettify_tokens=True,
        layer=None,
        heads=None,
        encoder_attention=None,
        decoder_attention=None,
        cross_attention=None,
        encoder_tokens=None,
        decoder_tokens=None,
        include_layers=None,
        html_action='view'
):
    """Render head view

        Args:
            For self-attention models:
                attention: list of ``torch.FloatTensor``(one for each layer) of shape
                    ``(batch_size(must be 1), num_heads, sequence_length, sequence_length)``
                tokens: list of tokens
                sentence_b_start: index of first wordpiece in sentence B if input text is sentence pair (optional)
            For encoder-decoder models:
                encoder_attention: list of ``torch.FloatTensor``(one for each layer) of shape
                    ``(batch_size(must be 1), num_heads, encoder_sequence_length, encoder_sequence_length)``
                decoder_attention: list of ``torch.FloatTensor``(one for each layer) of shape
                    ``(batch_size(must be 1), num_heads, decoder_sequence_length, decoder_sequence_length)``
                cross_attention: list of ``torch.FloatTensor``(one for each layer) of shape
                    ``(batch_size(must be 1), num_heads, decoder_sequence_length, encoder_sequence_length)``
                encoder_tokens: list of tokens for encoder input
                decoder_tokens: list of tokens for decoder input
            For all models:
                prettify_tokens: indicates whether to remove special characters in wordpieces, e.g. Ġ
                layer: index (zero-based) of initial selected layer in visualization. Defaults to layer 0.
                heads: Indices (zero-based) of initial selected heads in visualization. Defaults to all heads.
                include_layers: Indices (zero-based) of layers to include in visualization. Defaults to all layers.
                    Note: filtering layers may improve responsiveness of the visualization for long inputs.
                html_action: Specifies the action to be performed with the generated HTML object
                    - 'view' (default): Displays the generated HTML representation as a notebook cell output
                    - 'return' : Returns an HTML object containing the generated view for further processing or custom visualization
                    - 'gradio' : Returns a dict with keys 'html1' (HTML object holding the require.js script tag),
                        'html2' (HTML object holding the visualization markup) and 'params' (the parameter dict
                        that would be substituted into head_view.js). head_view.js is not read in this mode.

        Returns:
            None for 'view', an ``HTML`` object for 'return', a dict for 'gradio'.

        Raises:
            ValueError: if ``html_action`` is not a supported value, if no attention argument is given,
                if self-attention and encoder-decoder arguments are mixed, if a required token list is
                missing, if ``layer`` is not in ``include_layers``, or if the attention dimensions do not
                match the number of tokens.
    """

    # Fail fast: reject an unsupported action before converting any tensors to lists.
    if html_action not in _HTML_ACTIONS:
        raise ValueError("'html_action' parameter must be 'view', 'return' or 'gradio'")

    if attention is not None:
        if tokens is None:
            raise ValueError("'tokens' is required")
        if encoder_attention is not None or decoder_attention is not None or cross_attention is not None \
                or encoder_tokens is not None or decoder_tokens is not None:
            raise ValueError("If you specify 'attention' you may not specify any encoder-decoder arguments. This"
                             " argument is only for self-attention models.")
        if include_layers is None:
            include_layers = list(range(num_layers(attention)))
        attention = format_attention(attention, include_layers)
        attn_data = _self_attention_data(attention, tokens, sentence_b_start)
    elif encoder_attention is not None or decoder_attention is not None or cross_attention is not None:
        attn_data, include_layers = _encoder_decoder_data(
            encoder_attention, decoder_attention, cross_attention,
            encoder_tokens, decoder_tokens, include_layers
        )
    else:
        raise ValueError("You must specify at least one attention argument.")

    if layer is not None and layer not in include_layers:
        raise ValueError(f"Layer {layer} is not in include_layers: {include_layers}")

    # Fixed div id (the per-call uuid suffix is intentionally disabled in this variant).
    vis_id = 'bertviz'

    # Compose html
    vis_html = _compose_vis_html(vis_id, attn_data)

    _check_and_prettify(attn_data, prettify_tokens)

    params = {
        'attention': attn_data,
        'default_filter': "0",
        'root_div_id': vis_id,
        'layer': layer,
        'heads': heads,
        'include_layers': include_layers
    }

    require_html = HTML(_REQUIRE_JS_TAG)
    body_html = HTML(vis_html)

    if html_action == 'gradio':
        # The caller is responsible for combining the markup with the script.
        return {'html1': require_html, 'html2': body_html, 'params': params}

    vis_js = Javascript(_load_head_view_js(params))

    if html_action == 'view':
        display(require_html)
        display(body_html)
        display(vis_js)
        return None

    # html_action == 'return': bundle loader, markup and script into one HTML object.
    script = '\n<script type="text/javascript">\n' + vis_js.data + '\n</script>\n'
    return HTML(require_html.data + body_html.data + script)