"""Streamlit app for exploring token-length and n-gram statistics of BigBIO datasets."""

from collections import Counter

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import load_dataset
from matplotlib import pyplot as plt
from matplotlib_venn import venn2, venn3
from ngram import get_tuples_manual_sentences
from rich import print as rprint

from bigbio.dataloader import BigBioConfigHelpers

# from matplotlib_venn_wordcloud import venn2_wordcloud, venn3_wordcloud


# vanilla tokenizer: normalize tabs/newlines to spaces, then split on spaces
def tokenizer(text, counter):
    if not text:
        return text, []
    text = text.strip()
    text = text.replace("\t", " ")
    text = text.replace("\n", " ")
    # split
    text_list = text.split(" ")
    return text, text_list


def norm(lengths):
    # mean and standard deviation of a list of token lengths
    mu = np.mean(lengths)
    sigma = np.std(lengths)
    return mu, sigma


def load_helper():
    conhelps = BigBioConfigHelpers()
    conhelps = conhelps.filtered(lambda x: x.dataset_name != "pubtator_central")
    conhelps = conhelps.filtered(lambda x: x.is_bigbio_schema)
    conhelps = conhelps.filtered(lambda x: not x.is_local)
    rprint(
        "loaded {} configs from {} datasets".format(
            len(conhelps),
            len(set(helper.dataset_name for helper in conhelps)),
        )
    )
    return conhelps


# text fields to measure, per BigBIO schema
_TEXT_MAPS = {
    "bigbio_kb": ["text"],
    "bigbio_text": ["text"],
    "bigbio_qa": ["question", "context"],
    "bigbio_te": ["premise", "hypothesis"],
    "bigbio_tp": ["text_1", "text_2"],
    "bigbio_pairs": ["text_1", "text_2"],
    "bigbio_t2t": ["text_1", "text_2"],
}

IBM_COLORS = [
    "#648fff",
    "#dc267f",
    "#ffb000",
    "#fe6100",
    "#785ef0",
    "#000000",
    "#ffffff",
]

N = 3  # n-gram order


def token_length_per_entry(entry, schema, counter):
    result = {}
    if schema == "bigbio_kb":
        for passage in entry["passages"]:
            result_key = passage["type"]
            for key in _TEXT_MAPS[schema]:
                text = passage[key][0]
                sents, ngrams = get_tuples_manual_sentences(text.lower(), N)
                toks = [tok for sent in sents for tok in sent]
                tups = ["_".join(tup) for tup in ngrams]
                counter.update(tups)
                result[result_key] = len(toks)
    else:
        for key in _TEXT_MAPS[schema]:
            text = entry[key]
            sents, ngrams = get_tuples_manual_sentences(text.lower(), N)
            toks = [tok for sent in sents for tok in sent]
            result[key] = len(toks)
            tups = ["_".join(tup) for tup in ngrams]
            counter.update(tups)
    return result, counter


def parse_token_length_and_n_gram(dataset, data_config, st=None):
    hist_data = []
    n_gram_counters = []
    rprint(data_config)
    for split, data in dataset.items():
        my_bar = st.progress(0)
        total = len(data)
        n_gram_counter = Counter()
        for i, entry in enumerate(data):
            my_bar.progress(int(i / total * 100))
            result, n_gram_counter = token_length_per_entry(
                entry, data_config.schema, n_gram_counter
            )
            result["total_token_length"] = sum(result.values())
            result["split"] = split
            hist_data.append(result)
        # remove single count
        # n_gram_counter = Counter({x: count for x, count in n_gram_counter.items() if count > 1})
        n_gram_counters.append(n_gram_counter)
        my_bar.empty()
    st.write("token lengths complete!")
    return pd.DataFrame(hist_data), n_gram_counters


def center_title(fig):
    fig.update_layout(
        title={"y": 0.9, "x": 0.5, "xanchor": "center", "yanchor": "top"},
        font=dict(size=18),
    )
    return fig


def draw_histogram(hist_data, col_name, st=None):
    fig = px.histogram(
        hist_data,
        x=col_name,
        color="split",
        color_discrete_sequence=IBM_COLORS,
        marginal="box",  # or violin, rug
        barmode="group",
        hover_data=hist_data.columns,
        histnorm="probability",
        nbins=20,
        title=f"{col_name} distribution by split",
    )
    st.plotly_chart(center_title(fig), use_container_width=True)


def draw_bar(bar_data, x, y, st=None):
    fig = px.bar(
        bar_data,
        x=x,
        y=y,
        color="split",
        color_discrete_sequence=IBM_COLORS,
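        # NOTE: px.bar, unlike px.histogram, has no `marginal` argument,
        # so the option below stays commented out.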
marginal="box", # or violin, rug barmode="group", hover_data=bar_data.columns, title=f"{y} distribution by split", ) st.plotly_chart(center_title(fig), use_container_width=True) def parse_metrics(metadata, st=None): for k, m in metadata.items(): mattrs = m.__dict__ for m, attr in mattrs.items(): if type(attr) == int and attr > 0: st.metric(label=f"{k}-{m}", value=attr) def parse_counters(metadata): metadata = metadata["train"] # using the training counter to fetch the names counters = [] for k, v in metadata.__dict__.items(): if "counter" in k and len(v) > 0: counters.append(k) return counters # generate the df for histogram def parse_label_counter(metadata, counter_type): hist_data = [] for split, m in metadata.items(): metadata_counter = getattr(m, counter_type) for k, v in metadata_counter.items(): row = {} row["labels"] = k row[counter_type] = v row["split"] = split hist_data.append(row) return pd.DataFrame(hist_data) if __name__ == "__main__": # load helpers conhelps = load_helper() configs_set = set() for conhelper in conhelps: configs_set.add(conhelper.dataset_name) # st.write(sorted(configs_set)) # setup page, sidebar, columns st.set_page_config(layout="wide") s = st.session_state if not s: s.pressed_first_button = False data_name = st.sidebar.selectbox("dataset", sorted(configs_set)) st.sidebar.write("you selected:", data_name) st.header(f"Dataset stats for {data_name}") # setup data configs data_helpers = conhelps.for_dataset(data_name) data_configs = [d.config for d in data_helpers] data_config_names = [d.config.name for d in data_helpers] data_config_name = st.sidebar.selectbox("config", set(data_config_names)) if st.sidebar.button("fetch") or s.pressed_first_button: s.pressed_first_button = True helper = conhelps.for_config_name(data_config_name) metadata_helper = helper.get_metadata() parse_metrics(metadata_helper, st.sidebar) # load HF dataset data_idx = data_config_names.index(data_config_name) data_config = data_configs[data_idx] # st.write(data_name) dataset = load_dataset( f"bigbio/{data_name}", name=data_config_name ) ds = pd.DataFrame(dataset["train"]) st.write(ds) # general token length tok_hist_data, ngram_counters = parse_token_length_and_n_gram( dataset, data_config, st.sidebar ) # draw token distribution draw_histogram(tok_hist_data, "total_token_length", st) # general counter(s) col1, col2 = st.columns([1, 6]) counters = parse_counters(metadata_helper) counter_type = col1.selectbox("counter_type", counters) label_df = parse_label_counter(metadata_helper, counter_type) label_max = int(label_df[counter_type].max() - 1) label_min = int(label_df[counter_type].min()) filter_value = col1.slider("counter_filter (min, max)", label_min, label_max) label_df = label_df[label_df[counter_type] >= filter_value] # draw bar chart for counter draw_bar(label_df, "labels", counter_type, col2) venn_fig, ax = plt.subplots() if len(ngram_counters) == 2: union_counter = ngram_counters[0] + ngram_counters[1] print(ngram_counters[0].most_common(10)) print(ngram_counters[1].most_common(10)) total = len(union_counter.keys()) ngram_counter_sets = [ set(ngram_counter.keys()) for ngram_counter in ngram_counters ] venn2( ngram_counter_sets, dataset.keys(), set_colors=IBM_COLORS[:3], subset_label_formatter=lambda x: f"{(x/total):1.0%}", ) else: union_counter = ngram_counters[0] + ngram_counters[1] + ngram_counters[2] total = len(union_counter.keys()) ngram_counter_sets = [ set(ngram_counter.keys()) for ngram_counter in ngram_counters ] venn3( ngram_counter_sets, dataset.keys(), 
                set_colors=IBM_COLORS[:3],
                subset_label_formatter=lambda x: f"{(x / total):1.0%}",
            )
        venn_fig.suptitle(f"{N}-gram intersection for {data_name}", fontsize=20)
        st.pyplot(venn_fig)

    st.sidebar.button("Re-run")
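

# Usage sketch (assumptions, not part of the original file: the script is saved
# as `app.py`, and the `streamlit`, `datasets`, `plotly`, `matplotlib-venn`,
# `rich`, and `bigbio` packages plus the `ngram` helper module are importable):
#
#   streamlit run app.py
#
# Streamlit re-executes the whole script on every widget interaction; the
# `pressed_first_button` flag stored in `st.session_state` is what keeps the
# fetched stats visible across those re-runs after the first "fetch" click.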