from collections import Counter

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import load_dataset
from matplotlib import pyplot as plt
from matplotlib_venn import venn2, venn3
from ngram import get_tuples_manual_sentences
from rich import print as rprint

from bigbio.dataloader import BigBioConfigHelpers

# from matplotlib_venn_wordcloud import venn2_wordcloud, venn3_wordcloud

# vanilla whitespace tokenizer; `counter` is accepted but not used here
def tokenizer(text, counter):
    if not text:
        return text, []
    # normalize whitespace before splitting so tabs/newlines do not glue
    # adjacent words together
    text = text.strip()
    text = text.replace("\t", " ")
    text = text.replace("\n", " ")
    text_list = text.split()
    return text, text_list

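# Hypothetical example of the tokenizer's behavior (input invented for
# illustration):
#
#   text, toks = tokenizer("Protein\tkinase C\n", Counter())
#   # text -> "Protein kinase C"
#   # toks -> ["Protein", "kinase", "C"]
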
def norm(lengths):
    # mean and standard deviation of a token-length distribution
    mu = np.mean(lengths)
    sigma = np.std(lengths)
    return mu, sigma

def load_helper():
    # collect all remote bigbio-schema configs (pubtator_central is
    # explicitly excluded)
    conhelps = BigBioConfigHelpers()
    conhelps = conhelps.filtered(lambda x: x.dataset_name != "pubtator_central")
    conhelps = conhelps.filtered(lambda x: x.is_bigbio_schema)
    conhelps = conhelps.filtered(lambda x: not x.is_local)
    rprint(
        "loaded {} configs from {} datasets".format(
            len(conhelps),
            len(set(helper.dataset_name for helper in conhelps)),
        )
    )
    return conhelps

_TEXT_MAPS = {
    "bigbio_kb": ["text"],
    "bigbio_text": ["text"],
    "bigbio_qa": ["question", "context"],
    "bigbio_te": ["premise", "hypothesis"],
    "bigbio_tp": ["text_1", "text_2"],
    "bigbio_pairs": ["text_1", "text_2"],
    "bigbio_t2t": ["text_1", "text_2"],
}
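
# _TEXT_MAPS names the fields of an entry that carry free text for each
# BigBio schema. As a rough illustration (field values invented), a
# bigbio_qa entry looks like:
#
#   entry = {"question": "What does BRCA1 do?", "context": "BRCA1 is a ..."}
#
# so both "question" and "context" get tokenized and counted.
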
# IBM design-library palette (colorblind friendly)
IBM_COLORS = [
    "#648fff",
    "#dc267f",
    "#ffb000",
    "#fe6100",
    "#785ef0",
    "#000000",
    "#ffffff",
]

# n-gram order used for the intersection plots
N = 3

def token_length_per_entry(entry, schema, counter):
    # count tokens per text field and accumulate this entry's n-grams
    # into `counter`
    result = {}
    if schema == "bigbio_kb":
        # kb documents keep their text in typed passages (e.g. title/abstract)
        for passage in entry["passages"]:
            result_key = passage["type"]
            for key in _TEXT_MAPS[schema]:
                text = passage[key][0]
                sents, ngrams = get_tuples_manual_sentences(text.lower(), N)
                toks = [tok for sent in sents for tok in sent]
                tups = ["_".join(tup) for tup in ngrams]
                counter.update(tups)
                result[result_key] = len(toks)
    else:
        for key in _TEXT_MAPS[schema]:
            text = entry[key]
            sents, ngrams = get_tuples_manual_sentences(text.lower(), N)
            toks = [tok for sent in sents for tok in sent]
            result[key] = len(toks)
            tups = ["_".join(tup) for tup in ngrams]
            counter.update(tups)
    return result, counter

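# Hypothetical shape of the per-entry result for a bigbio_kb document with a
# title and an abstract passage (numbers invented):
#
#   result, counter = token_length_per_entry(entry, "bigbio_kb", Counter())
#   # result -> {"title": 12, "abstract": 180}
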
def parse_token_length_and_n_gram(dataset, data_config, st=None):
    # build one histogram row per entry and one n-gram Counter per split
    hist_data = []
    n_gram_counters = []
    rprint(data_config)
    for split, data in dataset.items():
        my_bar = st.progress(0)
        total = len(data)
        n_gram_counter = Counter()
        for i, entry in enumerate(data):
            my_bar.progress(int(i / total * 100))
            result, n_gram_counter = token_length_per_entry(
                entry, data_config.schema, n_gram_counter
            )
            result["total_token_length"] = sum(result.values())
            result["split"] = split
            hist_data.append(result)
        # remove single count
        # n_gram_counter = Counter({x: count for x, count in n_gram_counter.items() if count > 1})
        n_gram_counters.append(n_gram_counter)
        my_bar.empty()
    st.write("token lengths complete!")

    return pd.DataFrame(hist_data), n_gram_counters

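# The returned DataFrame has one row per entry, e.g. (values invented):
#
#   title  abstract  total_token_length  split
#   12     180       192                 train
#
# along with one Counter of n-grams per split, consumed by the venn
# diagrams at the bottom of the page.
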
def center_title(fig):
    fig.update_layout(
        title={"y": 0.9, "x": 0.5, "xanchor": "center", "yanchor": "top"},
        font=dict(
            size=18,
        ),
    )
    return fig

def draw_histogram(hist_data, col_name, st=None):
    fig = px.histogram(
        hist_data,
        x=col_name,
        color="split",
        color_discrete_sequence=IBM_COLORS,
        marginal="box",  # or violin, rug
        barmode="group",
        hover_data=hist_data.columns,
        histnorm="probability",
        nbins=20,
        title=f"{col_name} distribution by split",
    )
    st.plotly_chart(center_title(fig), use_container_width=True)

def draw_bar(bar_data, x, y, st=None):
    fig = px.bar(
        bar_data,
        x=x,
        y=y,
        color="split",
        color_discrete_sequence=IBM_COLORS,
        # marginal="box",  # or violin, rug
        barmode="group",
        hover_data=bar_data.columns,
        title=f"{y} distribution by split",
    )
    st.plotly_chart(center_title(fig), use_container_width=True)

def parse_metrics(metadata, st=None):
    # surface every positive integer attribute of each split's metadata
    for split, meta in metadata.items():
        for name, value in meta.__dict__.items():
            if isinstance(value, int) and value > 0:
                st.metric(label=f"{split}-{name}", value=value)

def parse_counters(metadata):
    # use the training-split metadata to discover the non-empty counter attributes
    metadata = metadata["train"]
    counters = []
    for k, v in metadata.__dict__.items():
        if "counter" in k and len(v) > 0:
            counters.append(k)
    return counters

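# parse_counters returns metadata attribute names such as (examples assumed,
# the actual names depend on the schema):
#
#   ["entities_type_counter", "relations_type_counter"]
#
# one of which the user picks in the selectbox below.
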
# generate the df for the label histogram
def parse_label_counter(metadata, counter_type):
    hist_data = []
    for split, m in metadata.items():
        metadata_counter = getattr(m, counter_type)
        for label, count in metadata_counter.items():
            hist_data.append(
                {"labels": label, counter_type: count, "split": split}
            )
    return pd.DataFrame(hist_data)

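# Resulting frame, one row per (label, split) pair (counts invented):
#
#   labels    entities_type_counter  split
#   disease   1204                   train
#   chemical  877                    train
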
if __name__ == "__main__":
    # load helpers
    conhelps = load_helper()
    configs_set = set()
    for conhelper in conhelps:
        configs_set.add(conhelper.dataset_name)
    # st.write(sorted(configs_set))

    # setup page, sidebar, columns
    st.set_page_config(layout="wide")
    s = st.session_state
    if "pressed_first_button" not in s:
        s.pressed_first_button = False
    data_name = st.sidebar.selectbox("dataset", sorted(configs_set))
    st.sidebar.write("you selected:", data_name)
    st.header(f"Dataset stats for {data_name}")

    # setup data configs
    data_helpers = conhelps.for_dataset(data_name)
    data_configs = [d.config for d in data_helpers]
    data_config_names = [d.config.name for d in data_helpers]
    data_config_name = st.sidebar.selectbox("config", set(data_config_names))

    if st.sidebar.button("fetch") or s.pressed_first_button:
        # keep the page rendered across widget interactions after the first fetch
        s.pressed_first_button = True
        helper = conhelps.for_config_name(data_config_name)
        metadata_helper = helper.get_metadata()
        parse_metrics(metadata_helper, st.sidebar)

        # load HF dataset
        data_idx = data_config_names.index(data_config_name)
        data_config = data_configs[data_idx]
        # st.write(data_name)
        dataset = load_dataset(f"bigbio/{data_name}", name=data_config_name)
        ds = pd.DataFrame(dataset["train"])
        st.write(ds)

        # general token length
        tok_hist_data, ngram_counters = parse_token_length_and_n_gram(
            dataset, data_config, st.sidebar
        )
        # draw token distribution
        draw_histogram(tok_hist_data, "total_token_length", st)

        # general counter(s)
        col1, col2 = st.columns([1, 6])
        counters = parse_counters(metadata_helper)
        counter_type = col1.selectbox("counter_type", counters)
        label_df = parse_label_counter(metadata_helper, counter_type)
        # cap the slider one below the largest count, presumably so the
        # filter can never hide every label
        label_max = int(label_df[counter_type].max() - 1)
        label_min = int(label_df[counter_type].min())
        filter_value = col1.slider("counter_filter (min count)", label_min, label_max)
        label_df = label_df[label_df[counter_type] >= filter_value]
        # draw bar chart for counter
        draw_bar(label_df, "labels", counter_type, col2)

        # n-gram overlap between splits; venn diagrams only exist for
        # 2- or 3-set comparisons, so other split counts draw nothing
        venn_fig, ax = plt.subplots()
        ngram_counter_sets = [
            set(ngram_counter.keys()) for ngram_counter in ngram_counters
        ]
        if len(ngram_counters) == 2:
            union_counter = ngram_counters[0] + ngram_counters[1]
            print(ngram_counters[0].most_common(10))
            print(ngram_counters[1].most_common(10))
            total = len(union_counter.keys())
            venn2(
                ngram_counter_sets,
                dataset.keys(),
                set_colors=IBM_COLORS[:2],
                subset_label_formatter=lambda x: f"{(x/total):1.0%}",
            )
        elif len(ngram_counters) == 3:
            union_counter = ngram_counters[0] + ngram_counters[1] + ngram_counters[2]
            total = len(union_counter.keys())
            venn3(
                ngram_counter_sets,
                dataset.keys(),
                set_colors=IBM_COLORS[:3],
                subset_label_formatter=lambda x: f"{(x/total):1.0%}",
            )
        venn_fig.suptitle(f"{N}-gram intersection for {data_name}", fontsize=20)
        st.pyplot(venn_fig)

    st.sidebar.button("Re-run")