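"""Streamlit app for interactive error-slice analysis of a sentiment classifier.

Loads precomputed predictions, losses, and 2-D embeddings for a test set,
lets the user pick a loss quantile that defines a high-loss ("error") slice,
and shows that slice as a table, a token-frequency breakdown, and an Altair
scatter plot of the embedding.
"""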
### LIBRARIES ###
# Data
import numpy as np
import pandas as pd
import torch
import json
from tqdm import tqdm
from math import floor
from datasets import load_dataset
from collections import defaultdict
from transformers import AutoTokenizer
# Analysis
# from gensim.models.doc2vec import Doc2Vec
# from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
# import nltk
# nltk.download('punkt')  # make sure that punkt is downloaded
# App & Visualization
import streamlit as st
import altair as alt
import plotly.graph_objects as go
from streamlit_vega_lite import altair_component
# Utils
from random import sample
from error_analysis import utils as ut
# from PIL import Image

def down_samp(embedding):
    """Down-sample a data frame for Altair visualization."""
    # total number of positive and negative sentiments in each slice
    # embedding = embedding.groupby('slice').apply(lambda x: x.sample(frac=0.3))
    total_size = embedding.groupby(['slice', 'label'], as_index=False).count()
    user_data = 0
    # if 'Your Sentences' in str(total_size['slice']):
    #     tmp = embedding.groupby(['slice'], as_index=False).count()
    #     val = int(tmp[tmp['slice'] == "Your Sentences"]['source'])
    #     user_data = val
    max_sample = total_size.groupby('slice').max()['content']
    # down-sample to stay within Altair's row limits,
    # but keep the proportional representation of groups
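    # (the budget is roughly 1000 plotted points in total, minus any user-supplied
    # sentences, which are kept in full)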
    samp_frac = 1 / (sum(max_sample.astype(float)) / (1000 - user_data))
    max_samp = max_sample.apply(lambda x: floor(x * samp_frac)).astype(int).to_dict()
    max_samp['Your Sentences'] = user_data
    # sample down each group in the data frame
    embedding = embedding.groupby('slice').apply(
        lambda x: x.sample(n=max_samp.get(x.name))).reset_index(drop=True)
    # order the embedding
    return embedding

def data_comparison(df):
    # set up a dropdown select binding
    # input_dropdown = alt.binding_select(options=['Negative Sentiment', 'Positive Sentiment'])
    selection = alt.selection_multi(fields=['slice', 'label'])
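    # clicking entries in the legend chart toggles (slice, label) pairs in this
    # selection; unselected points are dimmed via the opacity condition below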
    color = alt.condition(alt.datum.slice == 'high-loss', alt.value("orange"), alt.value("steelblue"))
    # color = alt.condition(selection,
    #                       alt.Color('slice:Q', legend=None),
    #                       # scale=alt.Scale(domain=pop_domain, range=color_range)),
    #                       alt.value('lightgray'))
    opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))

    # basic chart
    scatter = alt.Chart(df).mark_point(size=100, filled=True).encode(
        x=alt.X('x', axis=None),
        y=alt.Y('y', axis=None),
        color=color,
        shape=alt.Shape('label', scale=alt.Scale(range=['circle', 'diamond'])),
        tooltip=['slice', 'content', 'label', 'pred'],
        opacity=opacity
    ).properties(
        width=1500,
        height=1000
    ).interactive()

    legend = alt.Chart(df).mark_point().encode(
        y=alt.Y('slice:N', axis=alt.Axis(orient='right'), title=""),
        x=alt.X("label"),
        shape=alt.Shape('label', scale=alt.Scale(
            range=['circle', 'diamond']), legend=None),
        color=color
    ).add_selection(
        selection
    )

    layered = legend | scatter
    layered = layered.configure_axis(
        grid=False
    ).configure_view(
        strokeOpacity=0
    ).configure_legend(
        strokeColor='gray',
        fillColor='#EEEEEE',
        padding=10,
        cornerRadius=10,
        orient='top-right'
    )
    return layered

def quant_panel(embedding_df):
    """Quantitative panel layout."""
    all_metrics = {}
    # st.warning("**Data Comparison**")
    # with st.expander("how to read this chart:"):
    #     st.markdown("* each **point** is a single sentence")
    #     st.markdown("* the **position** of each dot is determined mathematically based upon an analysis of the words in a sentence. The **closer** two points are on the visualization, the **more similar** the sentences are. The **further apart** two points are, the **more different** the sentences are")
    #     st.markdown("* the **shape** of each point reflects whether it is a positive (diamond) or negative (circle) sentiment")
    #     st.markdown("* the **color** of each point is the ")
    st.altair_chart(data_comparison(down_samp(embedding_df)))

def frequent_tokens(data, tokenizer, loss_quantile=0.95, top_k=200, smoothing=0.005):
    unique_tokens = []
    tokens = []
    for row in tqdm(data['content']):
        tokenized = tokenizer(row, padding=True, return_tensors='pt')
        tokens.append(tokenized['input_ids'].flatten())
        unique_tokens.append(torch.unique(tokenized['input_ids']))
    losses = data['loss'].astype(float)
    high_loss = losses.quantile(loss_quantile)
    loss_weights = (losses > high_loss)
    loss_weights = loss_weights / loss_weights.sum()
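    # loss_weights is now a uniform distribution over the examples above the loss
    # quantile (the error slice) and zero elsewhere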
    token_frequencies = defaultdict(float)
    token_frequencies_error = defaultdict(float)
    weights_uniform = np.full_like(loss_weights, 1 / len(loss_weights))
    num_examples = len(data)
    for i in tqdm(range(num_examples)):
        for token in unique_tokens[i]:
            token_frequencies[token.item()] += weights_uniform[i]
            token_frequencies_error[token.item()] += loss_weights[i]
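    # smoothed ratio of each token's frequency in the error slice to its overall
    # frequency; high values mean the token is over-represented among high-loss examples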
    token_lrs = {k: (smoothing + token_frequencies_error[k]) / (smoothing + token_frequencies[k])
                 for k in token_frequencies}
    tokens_sorted = [k for k, _ in sorted(token_lrs.items(), key=lambda x: x[1], reverse=True)]
    top_tokens = []
    for token in tokens_sorted[:top_k]:
        top_tokens.append(['%10s' % tokenizer.decode(token),
                           '%.4f' % token_frequencies[token],
                           '%.4f' % token_frequencies_error[token],
                           '%4.2f' % token_lrs[token]])
    return pd.DataFrame(top_tokens, columns=['Token', 'Freq', 'Freq error slice', 'lrs'])

def get_data(spotlight, emb):
    preds = spotlight.outputs.numpy()
    losses = spotlight.losses.numpy()
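    # NOTE: this helper relies on a module-level `dataset` object (e.g. a Hugging Face
    # dataset with 'content' and 'label' columns); it is not wired up to the sidebar
    # selection below, which only stores the dataset *name* as a string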
    embeddings = pd.DataFrame(emb, columns=['x', 'y'])
    num_examples = len(losses)
    # dataset_labels = [dataset[i]['label'] for i in range(num_examples)]
    frame = pd.DataFrame(
        np.transpose(np.vstack([dataset[:num_examples]['content'],
                                dataset[:num_examples]['label'], preds, losses])),
        columns=['content', 'label', 'pred', 'loss'])
    return pd.concat([frame, embeddings], axis=1)

def topic_distribution(weights, smoothing=0.01):
    topic_frequencies = defaultdict(float)
    topic_frequencies_spotlight = defaultdict(float)
    weights_uniform = np.full_like(weights, 1 / len(weights))
    num_examples = len(weights)
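    # `dataset` is again the module-level dataset; each example is expected to
    # carry a 'title' field used as its topic/category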
    for i in range(num_examples):
        example = dataset[i]
        category = example['title']
        topic_frequencies[category] += weights_uniform[i]
        topic_frequencies_spotlight[category] += weights[i]
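    # smoothed ratio of each category's weight in the spotlight (error) slice to its
    # overall frequency, analogous to the token ratios above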
    topic_ratios = {c: (smoothing + topic_frequencies_spotlight[c]) /
                    (smoothing + topic_frequencies[c]) for c in topic_frequencies}
    categories_sorted = map(lambda x: x[0], sorted(
        topic_ratios.items(), key=lambda x: x[1], reverse=True))
    topic_distr = []
    for category in categories_sorted:
        topic_distr.append(['%.3f' % topic_frequencies[category],
                            '%.3f' % topic_frequencies_spotlight[category],
                            '%.2f' % topic_ratios[category],
                            '%s' % category])
    return pd.DataFrame(topic_distr, columns=['Overall frequency', 'Error frequency', 'Ratio', 'Category'])
    # for category in categories_sorted:
    #     return (topic_frequencies[category], topic_frequencies_spotlight[category], topic_ratios[category], category)

if __name__ == "__main__":
    ### STREAMLIT APP CONFIG ###
    st.set_page_config(layout="wide", page_title="Error Slice Analysis")
    ut.init_style()

    lcol, rcol = st.columns([2, 3])
    # ******* loading the model and the data
    with st.sidebar:
        st.title('Error Analysis')
        dataset = st.sidebar.selectbox(
            "Dataset",
            ["amazon_polarity", "squad", "movielens", "waterbirds"],
            index=0
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "distilbert-base-uncased-finetuned-sst-2-english")
        model = st.sidebar.selectbox(
            "Model",
            ["distilbert-base-uncased-finetuned-sst-2-english",
             "distilbert-base-uncased-finetuned-sst-2-english"],
            index=0
        )
        loss_quantile = st.sidebar.selectbox(
            "Loss Quantile",
            [0.98, 0.95, 0.9, 0.8, 0.75],
            index=1
        )
    ### LOAD DATA AND SESSION VARIABLES ###
    data_df = pd.read_parquet('./assets/data/amazon_polarity.test.parquet')
    data_df.reset_index(drop=True, inplace=True)
    embedding_umap = data_df[['x', 'y']]
    if "user_data" not in st.session_state:
        st.session_state["user_data"] = data_df
    if "selected_slice" not in st.session_state:
        st.session_state["selected_slice"] = None
    if "embedding" not in st.session_state:
        st.session_state["embedding"] = embedding_umap
    with lcol:
        st.title('Error Slices')
        dataframe = data_df[['content', 'label', 'pred', 'loss']].sort_values(
            by=['loss'], ascending=False)
        table_html = dataframe.to_html(
            columns=['content', 'label', 'pred', 'loss'], max_rows=100)
        # table_html = table_html.replace("<th>", '<th align="left">')  # left-align the headers
        st.write(dataframe)

        st.title('Word Distribution in Error Slice')
        commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
        st.write(commontokens)
        # st_aggrid.AgGrid(dataframe)
        # table_html = dataframe.to_html(columns=['content', 'label', 'pred', 'loss'], max_rows=100)
        # table_html = table_html.replace("<th>", '<th align="left">')  # left-align the headers
        # st.write(table_html)
    with rcol:
        data_df['loss'] = data_df['loss'].astype(float)
        losses = data_df['loss']
        high_loss = losses.quantile(loss_quantile)
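        # mark each row as high-loss or low-loss relative to the selected quantile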
        data_df['slice'] = 'high-loss'
        data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
        quant_panel(data_df)