# error-analysis / app.py
# Page header captured with the file: nazneen · commit "viz" (9ab2ff4) · 10.3 kB
### LIBRARIES ###
# # Data
from matplotlib.pyplot import legend
import numpy as np
import pandas as pd
import torch
import json
from tqdm import tqdm
from math import floor
from datasets import load_dataset
from collections import defaultdict
from transformers import AutoTokenizer
# Analysis
# from gensim.models.doc2vec import Doc2Vec
# from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
# import nltk
# nltk.download('punkt') #make sure that punkt is downloaded
# App & Visualization
import streamlit as st
import altair as alt
import plotly.graph_objects as go
from streamlit_vega_lite import altair_component
# utils
from random import sample
from error_analysis import utils as ut
# from PIL import Image
def down_samp(embedding, max_points=1000):
    """Down-sample a data frame so it is small enough for an Altair chart.

    Each slice receives a share of ``max_points`` proportional to the size
    of its largest ``(slice, label)`` group (counted as non-null ``content``
    values). That many rows are then drawn at random, without replacement,
    from the slice. A slice is never asked for more rows than it holds, so
    frames that are already small are returned in full instead of raising.

    Args:
        embedding: DataFrame with at least the columns ``slice``, ``label``
            and ``content``.
        max_points: Upper bound on the number of rows returned. Defaults to
            1000, the value that used to be hard-coded.

    Returns:
        A new DataFrame with a fresh ``RangeIndex``. Rows are grouped by
        slice (in sorted slice order); the order inside a slice is random.
    """
    # Placeholder for the number of user-entered rows. That feature is
    # disabled at the moment, so none of the budget is reserved for it.
    user_data = 0

    # Size of the largest label group inside every slice.
    label_counts = embedding.groupby(['slice', 'label'])['content'].count()
    largest_per_slice = label_counts.groupby(level='slice').max()
    total = sum(largest_per_slice.astype(float))
    if total <= 0:
        # Empty frame (or no countable content): there is nothing to scale.
        return embedding.reset_index(drop=True)

    # Scale the counts to the point budget while keeping the proportions
    # between slices.
    scale = 1 / (total / (max_points - user_data))
    quota = largest_per_slice.apply(
        lambda count: floor(count * scale)).astype(int).to_dict()
    quota['Your Sentences'] = user_data

    # Iterate over the groups explicitly (instead of ``groupby.apply``) so
    # the ``slice`` column is kept regardless of the pandas version, and cap
    # every request at the slice size so ``sample`` cannot over-draw.
    sampled = [
        rows.sample(n=min(quota.get(name, 0), len(rows)))
        for name, rows in embedding.groupby('slice')
    ]
    return pd.concat(sampled).reset_index(drop=True)
def data_comparison(df):
    """Build the linked legend + scatter Altair chart for the embedding.

    The chart reads the columns ``x``, ``y``, ``slice``, ``label``,
    ``content`` and ``pred`` from ``df``. Points of the ``high-loss`` slice
    are drawn in orange, all others in steel blue; the point shape encodes
    ``label``. A small legend chart on the left carries a multi-selection on
    ``(slice, label)``; points matching the selection are drawn more opaque
    than the rest.

    Args:
        df: DataFrame holding the columns listed above.

    Returns:
        The configured, horizontally concatenated Altair chart
        (legend | scatter).
    """
    # Selection driven by the legend chart below.
    picked = alt.selection_multi(fields=['slice', 'label'])

    # Shared encodings: slice decides the colour, the selection the opacity.
    slice_color = alt.condition(
        alt.datum.slice == 'high-loss',
        alt.value("orange"),
        alt.value("steelblue"),
    )
    point_opacity = alt.condition(picked, alt.value(0.7), alt.value(0.25))

    # Main scatter plot of the 2-D embedding; axes are hidden.
    scatter_marks = alt.Chart(df).mark_point(size=100, filled=True)
    scatter_encoded = scatter_marks.encode(
        x=alt.X('x', axis=None),
        y=alt.Y('y', axis=None),
        color=slice_color,
        shape=alt.Shape('label', scale=alt.Scale(range=['circle', 'diamond'])),
        tooltip=['slice', 'content', 'label', 'pred'],
        opacity=point_opacity,
    )
    scatter_plot = scatter_encoded.properties(width=1500, height=1000).interactive()

    # Clickable slice x label grid that doubles as the legend.
    legend_marks = alt.Chart(df).mark_point()
    legend_encoded = legend_marks.encode(
        y=alt.Y('slice:N', axis=alt.Axis(orient='right'), title=""),
        x=alt.X("label"),
        shape=alt.Shape(
            'label',
            scale=alt.Scale(range=['circle', 'diamond']),
            legend=None,
        ),
        color=slice_color,
    )
    legend_panel = legend_encoded.add_selection(picked)

    combined = legend_panel | scatter_plot
    without_grid = combined.configure_axis(grid=False)
    without_frame = without_grid.configure_view(strokeOpacity=0)
    return without_frame.configure_legend(
        strokeColor='gray',
        fillColor='#EEEEEE',
        padding=10,
        cornerRadius=10,
        orient='top-right',
    )
def quant_panel(embedding_df):
    """Render the quantitative panel: the down-sampled embedding chart.

    The frame is first reduced with ``down_samp`` and then turned into an
    Altair chart by ``data_comparison``; the chart is written to the current
    Streamlit container.

    Args:
        embedding_df: DataFrame forwarded to ``down_samp``.

    Returns:
        None. The chart is emitted through ``st.altair_chart``.
    """
    # Disabled help text, kept for when the expander is re-enabled:
    # st.warning("**Data Comparison**")
    # with st.expander("how to read this chart:"):
    #     st.markdown("* each **point** is a single sentence")
    #     st.markdown("* the **position** of each dot is determined mathematically based upon an analysis of the words in a sentence. The **closer** two points on the visualization the **more similar** the sentences are. The **further apart ** two points on the visualization the **more different** the sentences are")
    #     st.markdown(
    #         " * the **shape** of each point reflects whether it a positive (diamond) or negative sentiment (circle)")
    #     st.markdown("* the **color** of each point is the ")
    chart_data = down_samp(embedding_df)
    st.altair_chart(data_comparison(chart_data))
def frequent_tokens(data, tokenizer, loss_quantile=0.95, top_k=200, smoothing=0.005):
    """Rank tokens by how over-represented they are among high-loss rows.

    Every row of ``data['content']`` is tokenized and reduced to its set of
    distinct token ids (all ids the tokenizer returns are counted). Two
    frequencies are accumulated per token:

    * the share of all rows that contain the token, and
    * the share of "error slice" rows (loss strictly above the
      ``loss_quantile`` quantile) that contain the token.

    Tokens are ranked by the smoothed ratio of the second to the first.

    Args:
        data: DataFrame with the columns ``content`` and ``loss``; ``loss``
            must be convertible to float.
        tokenizer: Callable used as
            ``tokenizer(text, padding=True, return_tensors='pt')`` that
            returns a mapping with an ``input_ids`` tensor, and that offers
            ``decode(token_id)``.
        loss_quantile: Quantile of the loss above which a row belongs to the
            error slice.
        top_k: Number of top-ranked tokens to return.
        smoothing: Constant added to numerator and denominator of the ratio.

    Returns:
        DataFrame with the string columns ``Token``, ``Freq``,
        ``Freq error slice`` and ``lrs`` (the smoothed ratio), at most
        ``top_k`` rows, highest ratio first. Empty when ``data`` has no rows.
    """
    columns = ['Token', 'Freq', 'Freq error slice', 'lrs']
    num_examples = len(data)
    if num_examples == 0:
        # Avoids the 1 / 0 below; there is nothing to rank.
        return pd.DataFrame([], columns=columns)

    # Distinct token ids of every row, as plain Python ints.
    unique_tokens = []
    for row in tqdm(data['content']):
        tokenized = tokenizer(row, padding=True, return_tensors='pt')
        unique_tokens.append(torch.unique(tokenized['input_ids']).tolist())

    losses = data['loss'].astype(float)
    high_loss = losses.quantile(loss_quantile)
    # Convert to a NumPy array so that ``loss_weights[i]`` in the loop below
    # is positional; indexing the Series would look up the *label* ``i`` and
    # break for frames whose index is not 0..n-1.
    in_error_slice = (losses > high_loss).to_numpy()
    num_errors = in_error_slice.sum()
    if num_errors:
        loss_weights = in_error_slice / num_errors
    else:
        # No row lies strictly above the quantile (e.g. all losses equal):
        # use zero weights rather than the NaNs that 0 / 0 would produce.
        loss_weights = np.zeros(num_examples)
    weights_uniform = np.full(num_examples, 1 / num_examples)

    token_frequencies = defaultdict(float)
    token_frequencies_error = defaultdict(float)
    for i in tqdm(range(num_examples)):
        for token in unique_tokens[i]:
            token_frequencies[token] += weights_uniform[i]
            token_frequencies_error[token] += loss_weights[i]

    token_lrs = {
        token: (smoothing + token_frequencies_error[token]) / (smoothing + frequency)
        for token, frequency in token_frequencies.items()
    }
    # Ascending stable sort, then reversed: this keeps the tie order the
    # function has always produced.
    ranked = sorted(token_lrs.items(), key=lambda item: item[1])[::-1]

    top_tokens = [
        [
            '%10s' % (tokenizer.decode(token)),
            '%.4f' % (token_frequencies[token]),
            '%.4f' % (token_frequencies_error[token]),
            '%4.2f' % (ratio),
        ]
        for token, ratio in ranked[:top_k]
    ]
    return pd.DataFrame(top_tokens, columns=columns)
@st.cache(ttl=600)
def get_data(spotlight, emb):
    """Combine texts, labels, predictions, losses and 2-D coordinates.

    Reads the first ``len(spotlight.losses)`` examples from the module-level
    ``dataset`` and places their ``content`` and ``label`` next to the
    predictions and losses taken from ``spotlight`` and the coordinates in
    ``emb``. The result is cached by Streamlit for 600 seconds.

    Args:
        spotlight: Object exposing ``outputs`` and ``losses``, each with a
            ``numpy()`` method (assumed to yield one value per example --
            TODO confirm against the caller).
        emb: Two-column data used to build the ``x`` / ``y`` columns.

    Returns:
        DataFrame with the columns ``content``, ``label``, ``pred``,
        ``loss``, ``x`` and ``y``. The first four come from a single stacked
        NumPy array and therefore share one dtype.
    """
    preds = spotlight.outputs.numpy()
    losses = spotlight.losses.numpy()
    coordinates = pd.DataFrame(emb, columns=['x', 'y'])
    num_examples = len(losses)

    # One row per field, then transposed so that every field is a column.
    head = dataset[:num_examples]
    stacked = np.vstack([head['content'], head['label'], preds, losses])
    records = pd.DataFrame(
        np.transpose(stacked),
        columns=['content', 'label', 'pred', 'loss'],
    )
    return pd.concat([records, coordinates], axis=1)
def topic_distribution(weights, smoothing=0.01, data=None):
    """Compare how often each category occurs overall and under ``weights``.

    For every example ``i`` its ``'title'`` is read as the category. The
    category receives ``1 / len(weights)`` towards its overall frequency and
    ``weights[i]`` towards its weighted ("error") frequency. Categories are
    ranked by the smoothed ratio of weighted to overall frequency.

    Args:
        weights: Indexable sequence with one weight per example.
        smoothing: Constant added to numerator and denominator of the ratio.
        data: Indexable collection whose items expose ``['title']``. When
            omitted, the module-level ``dataset`` is used, as before.

    Returns:
        DataFrame with the string columns ``Overall frequency``,
        ``Error frequency``, ``Ratio`` and ``Category``, highest ratio
        first. Empty when ``weights`` is empty.
    """
    columns = ['Overall frequency', 'Error frequency', 'Ratio', 'Category']
    num_examples = len(weights)
    if num_examples == 0:
        # Avoids the 1 / 0 below; there is nothing to tabulate.
        return pd.DataFrame([], columns=columns)

    examples = dataset if data is None else data
    # A plain float: ``np.full_like(weights, ...)`` would inherit the dtype
    # of ``weights`` and truncate the value to 0 for integer weights.
    uniform_weight = 1 / num_examples

    topic_frequencies = defaultdict(float)
    topic_frequencies_spotlight = defaultdict(float)
    for i in range(num_examples):
        category = examples[i]['title']
        topic_frequencies[category] += uniform_weight
        topic_frequencies_spotlight[category] += weights[i]

    topic_ratios = {
        category: (smoothing + topic_frequencies_spotlight[category]) / (smoothing + frequency)
        for category, frequency in topic_frequencies.items()
    }
    ranked = sorted(topic_ratios.items(), key=lambda item: item[1], reverse=True)

    topic_distr = [
        [
            '%.3f' % topic_frequencies[category],
            '%.3f' % topic_frequencies_spotlight[category],
            '%.2f' % ratio,
            '%s' % category,
        ]
        for category, ratio in ranked
    ]
    return pd.DataFrame(topic_distr, columns=columns)
# for category in categories_sorted:
# return(topic_frequencies[category], topic_frequencies_spotlight[category], topic_ratios[category], category)
if __name__ == "__main__":
    ### STREAMLIT APP CONFIG ###
    st.set_page_config(layout="wide", page_title="Error Slice Analysis")
    ut.init_style()
    lcol, rcol = st.columns([2, 3])

    # ******* loading the model and the data
    with st.sidebar:
        st.title('Error Analysis')

    # NOTE: the dataset and model selections are shown in the sidebar but are
    # not used further down: the data is always read from the amazon_polarity
    # parquet file and the tokenizer name is fixed.
    dataset = st.sidebar.selectbox(
        "Dataset",
        ["amazon_polarity", "squad", "movielens", "waterbirds"],
        index=0
    )

    tokenizer = AutoTokenizer.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english")

    model = st.sidebar.selectbox(
        "Model",
        ["distilbert-base-uncased-finetuned-sst-2-english",
         "distilbert-base-uncased-finetuned-sst-2-english"],
        index=0
    )

    loss_quantile = st.sidebar.selectbox(
        "Loss Quantile",
        [0.98, 0.95, 0.9, 0.8, 0.75],
        index=1
    )

    ### LOAD DATA AND SESSION VARIABLES ###
    data_df = pd.read_parquet('./assets/data/amazon_polarity.test.parquet')
    data_df.reset_index(drop=True, inplace=True)
    # Cast the loss once, right after loading, so that the sort in the left
    # column and the quantile in the right column both work on numbers (a
    # text-typed column would otherwise be sorted lexicographically).
    data_df['loss'] = data_df['loss'].astype(float)
    embedding_umap = data_df[['x', 'y']]

    # Seed the session state on the first run only.
    if "user_data" not in st.session_state:
        st.session_state["user_data"] = data_df
    if "selected_slice" not in st.session_state:
        st.session_state["selected_slice"] = None
    if "embedding" not in st.session_state:
        st.session_state["embedding"] = embedding_umap

    with lcol:
        # Table of all examples, highest loss first.
        st.markdown('<h3>Error Slices</h3>', unsafe_allow_html=True)
        dataframe = data_df[['content', 'label', 'pred', 'loss']].sort_values(
            by=['loss'], ascending=False)
        st.write(dataframe)

        # Tokens that are over-represented among the high-loss examples.
        st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
        commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
        st.write(commontokens)

    with rcol:
        # Label every row as high-loss (strictly above the chosen quantile)
        # or low-loss, then draw the embedding chart.
        losses = data_df['loss']
        high_loss = losses.quantile(loss_quantile)
        data_df['slice'] = 'high-loss'
        data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
        quant_panel(data_df)