# bigbio_test / vis_data_card.py
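"""Generate BigBio data-card figures and LaTeX snippets.

For the dataset named on the command line, every public bigbio-schema config
is loaded, token-length distributions and metadata label counters are plotted
with plotly, the figure is exported to figures/data_card/, and a matching
LaTeX section is written to tex/.
"""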
# from matplotlib_venn import venn2, venn3
import json
import sys
from collections import Counter

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from datasets import load_dataset
from plotly.subplots import make_subplots
from rich import print as rprint

from bigbio.dataloader import BigBioConfigHelpers
from ngram import get_tuples_manual_sentences

# disable MathJax so kaleido does not stamp a "Loading [MathJax]" box into the exported PDFs
pio.kaleido.scope.mathjax = None
# vanilla tokenizer
def tokenizer(text, counter):
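    """Whitespace tokenizer: strip tabs/newlines and split on single spaces.

    Returns the cleaned text and its token list; the `counter` argument is
    currently unused.
    """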
if not text:
return text, []
text = text.strip()
text = text.replace("\t", "")
text = text.replace("\n", "")
# split
text_list = text.split(" ")
return text, text_list
def norm(lengths):
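    """Return the mean and standard deviation of a list of lengths."""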
mu = np.mean(lengths)
sigma = np.std(lengths)
return mu, sigma
def load_helper(local=""):
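    """Load config helpers, either from a local metadata JSON dump or via BigBioConfigHelpers.

    Without `local`, only public bigbio-schema configs are kept
    (pubtator_central and local-only datasets are filtered out).
    """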
    if local != "":
        # read a pre-computed metadata dump instead of building the helpers
        with open(local, "r") as file:
            conhelps = json.load(file)
    else:
        conhelps = BigBioConfigHelpers()
        conhelps = conhelps.filtered(lambda x: x.dataset_name != "pubtator_central")
        conhelps = conhelps.filtered(lambda x: x.is_bigbio_schema)
        conhelps = conhelps.filtered(lambda x: not x.is_local)
        rprint(
            "loaded {} configs from {} datasets".format(
                len(conhelps),
                len(set(helper.dataset_name for helper in conhelps)),
            )
        )
    return conhelps
_TEXT_MAPS = {
"bigbio_kb": ["text"],
"bigbio_text": ["text"],
"bigbio_qa": ["question", "context"],
"bigbio_te": ["premise", "hypothesis"],
"bigbio_tp": ["text_1", "text_2"],
"bigbio_pairs": ["text_1", "text_2"],
"bigbio_t2t": ["text_1", "text_2"],
}
IBM_COLORS = [
"#648fff", # train
"#dc267f", # val
"#ffb000", # test
"#fe6100",
"#785ef0",
"#000000",
"#ffffff",
]
SPLIT_COLOR_MAP = {
"train": "#648fff",
"validation": "#dc267f",
"test": "#ffb000",
}
N = 3  # n-gram order used when updating the n-gram counters
def token_length_per_entry(entry, schema, counter):
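    """Compute the token length of one entry and update the n-gram counter.

    KB-schema entries are traversed passage by passage; for all other schemas
    the text fields listed in _TEXT_MAPS are used.
    """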
result = {}
entry_id = entry['id']
if schema == "bigbio_kb":
for passage in entry["passages"]:
result_key = passage["type"]
for key in _TEXT_MAPS[schema]:
text = passage[key][0]
if not text:
print(f"WARNING: text key does not exist: entry {entry_id}")
result["token_length"] = 0
result["text_type"] = result_key
continue
sents, ngrams = get_tuples_manual_sentences(text.lower(), N)
toks = [tok for sent in sents for tok in sent]
tups = ["_".join(tup) for tup in ngrams]
counter.update(tups)
result["token_length"] = len(toks)
result["text_type"] = result_key
else:
for key in _TEXT_MAPS[schema]:
text = entry[key]
            if not text:
                print(f"WARNING: empty text for field '{key}' in entry {entry_id}")
                result["token_length"] = 0
                result["text_type"] = key
                continue
            sents, ngrams = get_tuples_manual_sentences(text.lower(), N)
            toks = [tok for sent in sents for tok in sent]
            result["token_length"] = len(toks)
            result["text_type"] = key
            tups = ["_".join(tup) for tup in ngrams]
            counter.update(tups)
return result, counter
def parse_token_length_and_n_gram(dataset, schema_type):
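    """Collect per-entry token lengths (as a DataFrame) and one n-gram Counter per split."""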
hist_data = []
n_gram_counters = []
for split, data in dataset.items():
n_gram_counter = Counter()
for i, entry in enumerate(data):
result, n_gram_counter = token_length_per_entry(
entry, schema_type, n_gram_counter
)
result["split"] = split
hist_data.append(result)
n_gram_counters.append(n_gram_counter)
return pd.DataFrame(hist_data), n_gram_counters
def resolve_splits(df_split):
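    """Keep only the canonical train/validation/test splits from a list of split names."""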
official_splits = set(df_split).intersection(set(SPLIT_COLOR_MAP.keys()))
return official_splits
def draw_box(df, col_name, row, col, fig):
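    """Add one box trace per official split for `col_name` to the subplot at (row, col)."""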
splits = resolve_splits(df["split"].unique())
for split in splits:
split_count = df.loc[df["split"] == split, col_name].tolist()
print(split)
fig.add_trace(
go.Box(
x=split_count,
name=split,
marker_color=SPLIT_COLOR_MAP[split.split("_")[0]],
),
row=row,
col=col,
)
def draw_bar(df, col_name, y_name, row, col, fig):
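    """Add one horizontal bar trace per official split (counts on x, label names on y)."""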
splits = resolve_splits(df["split"].unique())
for split in splits:
split_count = df.loc[df["split"] == split, col_name].tolist()
y_list = df.loc[df["split"] == split, y_name].tolist()
fig.add_trace(
go.Bar(
x=split_count,
y=y_list,
name=split,
marker_color=SPLIT_COLOR_MAP[split.split("_")[0]],
showlegend=False,
),
row=row,
col=col,
)
    fig.update_traces(orientation="h")  # render the bar (and box) traces horizontally
def parse_counters(metadata):
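    """Return the names of all non-empty `*_counter` attributes on the split metadata."""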
    metadata = metadata[
        list(metadata.keys())[0]
    ]  # use the first split's metadata object to fetch the counter names
counters = []
for k, v in metadata.__dict__.items():
if "counter" in k and len(v) > 0:
counters.append(k)
return counters
# generate the df for histogram
def parse_label_counter(metadata, counter_type):
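    """Flatten a metadata counter into a long-format DataFrame with columns: labels, <counter_type>, split."""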
hist_data = []
for split, m in metadata.items():
metadata_counter = getattr(m, counter_type)
for k, v in metadata_counter.items():
row = {}
row["labels"] = k
row[counter_type] = v
row["split"] = split
hist_data.append(row)
return pd.DataFrame(hist_data)
def gen_latex(dataset_name, helper, splits, schemas, fig_path):
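    """Assemble the LaTeX data-card section: figure include, description, homepage, license, languages, tasks, schemas and splits."""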
if type(helper.description) is dict:
        # TODO hacky, change this to include all descriptions
descriptions = helper.description[list(helper.description.keys())[0]]
else:
descriptions = helper.description
descriptions = descriptions.replace("\n", "").replace("\t", "")
langs = [l.value for l in helper.languages]
languages = " ".join(langs)
if type(helper.license) is dict:
license = helper.license.value.name
else:
license = helper.license.name
tasks = [" ".join(t.name.lower().split("_")) for t in helper.tasks]
tasks = ", ".join(tasks)
schemas = " ".join([r"{\tt "] + list(schemas) + ["}"]) # TODO \tt
splits = ", ".join(list(splits))
    data_name_display = " ".join(dataset_name.split("_"))
latex_bod = r"\clearpage" + "\n" + r"\section*{" + fr"{data_name_display}" + " Data Card" + r"}" + "\n"
latex_bod += (
r"\begin{figure}[ht!]"
+ "\n"
+ r"\centering"
+ "\n"
+ r"\includegraphics[width=\linewidth]{"
)
latex_bod += f"{fig_path}" + r"}" + "\n"
latex_bod += r"\caption{\label{fig:"
    latex_bod += fr"{dataset_name}" + r"}"
latex_bod += (
r"Token frequency distribution by split (top) and frequency of different kind of instances (bottom).}"
+ "\n"
)
latex_bod += r"\end{figure}" + "\n" + r"\textbf{Dataset Description} "
latex_bod += (
fr"{descriptions}"
+ "\n"
+ r"\textbf{Homepage:} "
+ f"{helper.homepage}"
+ "\n"
+ r"\textbf{URL:} "
+ f"{helper.homepage}" # TODO change this later
+ "\n"
+ r"\textbf{Licensing:} "
+ f"{license}"
+ "\n"
+ r"\textbf{Languages:} "
+ f"{languages}"
+ "\n"
+ r"\textbf{Tasks:} "
+ f"{tasks}"
+ "\n"
+ r"\textbf{Schemas:} "
+ f"{schemas}"
+ "\n"
+ r"\textbf{Splits:} "
+ f"{splits}"
)
return latex_bod
def write_latex(latex_body, latex_name):
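    """Write the LaTeX body to tex/<latex_name>."""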
    with open(f"tex/{latex_name}", "w") as text_file:
        text_file.write(latex_body)
def draw_figure(data_name, data_config_name, schema_type):
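    """Build the data-card figure for one config and export it as a PDF.

    The figure holds a token-length box plot (one box per split) plus one bar
    chart per metadata counter; returns (helper, splits, fig_path).
    """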
helper = conhelps.for_config_name(data_config_name)
metadata_helper = helper.get_metadata() # calls load_dataset for meta parsing
rprint(metadata_helper)
splits = metadata_helper.keys()
# calls HF load_dataset _again_ for token parsing
dataset = load_dataset(
f"bigbio/biodatasets/{data_name}/{data_name}.py", name=data_config_name
)
# general token length
tok_hist_data, ngram_counters = parse_token_length_and_n_gram(dataset, schema_type)
rprint(helper)
# general counter(s)
# TODO generate the pdf and fix latex
counters = parse_counters(metadata_helper)
print(counters)
    rows = len(counters) // 3
    if len(counters) >= 3:
        # counters = counters[:3]
        cols = 3
        specs = [[{"colspan": 3}, None, None]] + [[{}, {}, {}]] * (rows + 1)
    elif len(counters) == 2:
        cols = 2
        specs = [[{"colspan": 2}, None]] + [[{}, {}]] * (rows + 1)
    elif len(counters) == 1:
        cols = 1
        specs = [[{}], [{}]]
    else:
        # no metadata counters found: keep room only for the token-length plot
        cols = 1
        specs = [[{}], [{}]]
counters.sort()
counter_titles = ["Label Counts by Type: " + ct.split("_")[0] for ct in counters]
titles = ("token length",) + tuple(counter_titles)
# Make figure with subplots
fig = make_subplots(
rows=rows + 2,
cols=cols,
subplot_titles=titles,
specs=specs,
vertical_spacing=0.10,
horizontal_spacing=0.10,
)
# draw token distribution
if "token_length" in tok_hist_data.keys():
draw_box(tok_hist_data, "token_length", row=1, col=1, fig=fig)
for i, ct in enumerate(counters):
row = i // 3 + 2
col = i % 3 + 1
label_df = parse_label_counter(metadata_helper, ct)
        label_min = int(label_df[ct].min())
        # filter_value = int((label_max - label_min) * 0.01 + label_min)
        # NOTE: filtering by the minimum keeps every label; the commented-out line above was a stricter threshold
        label_df = label_df[label_df[ct] >= label_min]
print(label_df.head(5))
# draw bar chart for counter
draw_bar(label_df, ct, "labels", row=row, col=col, fig=fig)
fig.update_annotations(font_size=12)
fig.update_layout(
margin=dict(l=25, r=25, t=25, b=25, pad=2),
# showlegend=False,
# title_text=data_name,
height=600,
width=1000,
)
# fig.show()
fig_name = f"{data_name}_{data_config_name}.pdf"
fig_path = f"figures/data_card/{fig_name}"
fig.write_image(fig_path)
dataset.cleanup_cache_files()
return helper, splits, fig_path
if __name__ == "__main__":
# load helpers
# each entry in local metadata is the dataset name
dc_local = load_helper(local="scripts/bigbio-public-metadatas-6-8.json")
# each entry is the config
conhelps = load_helper()
dc = list()
# TODO uncomment this
# for conhelper in conhelps:
# # print(f"{conhelper.dataset_name}-{conhelper.config.subset_id}-{conhelper.config.schema}")
# dc.append(conhelper.dataset_name)
# datacard per data, metadata chart per config
# for data_name, meta in dc_local.items():
# config_metas = meta['config_metas']
# config_metas_keys = config_metas.keys()
# if len(config_metas_keys) > 1:
# print(f'dataset {data_name} has more than one config')
# schemas = set()
# for config_name, config in config_metas.items():
# bigbio_schema = config['bigbio_schema']
# helper, splits, fig_path = draw_figure(data_name, config_name, bigbio_schema)
# schemas.add(helper.bigbio_schema_caps)
# latex_bod = gen_latex(data_name, helper, splits, schemas, fig_path)
# latex_name = f"{data_name}_{config_name}.tex"
# write_latex(latex_bod, latex_name)
# print(latex_bod)
# TODO try this code first, then use this for the whole loop
# skipped medal, too large, no nagel/pcr/pubtator_central/spl_adr_200db in local
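    # usage (assumed to be run from the repo root so the relative figures/ and tex/ paths resolve):
    #   python vis_data_card.py <dataset_name>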
data_name = sys.argv[1]
schemas = set()
# LOCAL
# meta = dc_local[data_name]
# config_metas = meta['config_metas']
# config_metas_keys = config_metas.keys()
# if len(config_metas_keys) >= 1:
# print(f'dataset {data_name} has more than one config')
# for config_name, config in config_metas.items():
# bigbio_schema = config['bigbio_schema']
# helper, splits, fig_path = draw_figure(data_name, config_name, bigbio_schema)
# schemas.add(helper.bigbio_schema_caps)
# latex_bod = gen_latex(data_name, helper, splits, schemas, fig_path)
# latex_name = f"{data_name}_{config_name}.tex"
# write_latex(latex_bod, latex_name)
# print(latex_bod)
# NON LOCAL
config_helpers = conhelps.for_dataset(data_name)
for config_helper in config_helpers:
rprint(config_helper)
bigbio_schema = config_helper.config.schema
config_name = config_helper.config.name
helper, splits, fig_path = draw_figure(data_name, config_name, bigbio_schema)
schemas.add(helper.bigbio_schema_caps)
latex_bod = gen_latex(data_name, helper, splits, schemas, fig_path)
latex_name = f"{data_name}_{config_name}.tex"
write_latex(latex_bod, latex_name)
print(latex_bod)