ExplaiNER / src /subpages /raw_data.py
ceyda's picture
Duplicate from aseifert/ExplaiNER
2d4811a
"""See the data as seen by your model."""
import pandas as pd
import streamlit as st
from src.subpages.page import Context, Page
from src.utils import aggrid_interactive_table
@st.cache
def convert_df(df):
    """Serialize *df* to CSV text and return it as UTF-8 encoded bytes."""
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")
class RawDataPage(Page):
    """Page showing the data as seen by the model.

    Renders a short dataset summary followed by two interactive tables (the
    processed token-level data and the raw split exploded per token), each
    accompanied by a CSV download button.
    """

    name = "Raw data"
    icon = "qr-code"

    @staticmethod
    def _show_table_with_download(table, csv_file_name):
        """Display *table* interactively and offer it as a CSV download."""
        aggrid_interactive_table(table)
        csv_bytes = convert_df(table)
        st.download_button("Download csv", csv_bytes, csv_file_name, "text/csv")

    def render(self, context: Context):
        """Draw the page: title, help text, dataset info and both data tables."""
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("See the data as seen by your model.")

        st.subheader("Dataset")
        dataset_summary = (
            f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
        )
        st.code(dataset_summary)

        st.write("**Data after processing and inference**")
        # Drop the two columns that are not displayed, then round to 3 decimals.
        tokens_df = context.df_tokens.drop(columns=["hidden_states", "attention_mask"]).round(3)
        wanted_cols = (
            "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
        )
        # token_type_ids is only selected when the frame actually has that column.
        has_token_types = "token_type_ids" in tokens_df.columns
        shown_cols = [col for col in wanted_cols if has_token_types or col != "token_type_ids"]
        tokens_df = tokens_df[shown_cols]
        self._show_table_with_download(tokens_df, "processed_data.csv")

        st.write("**Raw data (exploded by tokens)**")
        # Explode every column so that each list element ends up on its own row.
        exploded_df = context.split.to_pandas().apply(pd.Series.explode)  # type: ignore
        self._show_table_with_download(exploded_df, "raw_data.csv")