|
"""See the data as seen by your model.""" |
|
import pandas as pd |
|
import streamlit as st |
|
|
|
from src.subpages.page import Context, Page |
|
from src.utils import aggrid_interactive_table |
|
|
|
|
|
@st.cache
def convert_df(df: pd.DataFrame) -> bytes:
    """Serialise *df* to CSV and return the result as UTF-8 encoded bytes.

    The function is wrapped in Streamlit's cache decorator, so repeated calls
    with the same dataframe can reuse the previously produced payload.
    """
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")
|
|
|
|
|
class RawDataPage(Page):
    """Page that displays the processed token data and the raw dataset split."""

    name = "Raw data"
    icon = "qr-code"

    def render(self, context: Context):
        """Draw the page: dataset identifiers, then two tables with CSV downloads.

        Args:
            context: Provides the dataset/config/split names shown in the
                header, the token-level dataframe (``df_tokens``) and the
                split object (``split``) that is converted via ``to_pandas()``.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("See the data as seen by your model.")

        # --- Which dataset / config / split is being shown -----------------
        st.subheader("Dataset")
        dataset_summary = f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
        st.code(dataset_summary)

        # --- Token-level data after processing and inference ---------------
        st.write("**Data after processing and inference**")

        # Remove the hidden_states and attention_mask columns, then round
        # numeric values to three decimals.
        trimmed = context.df_tokens.drop(columns=["hidden_states", "attention_mask"])
        rounded = trimmed.round(3)

        # Fixed display order; token_type_ids is shown only when the
        # dataframe actually has that column.
        wanted = [
            "ids",
            "input_ids",
            "token_type_ids",
            "word_ids",
            "losses",
            "tokens",
            "labels",
            "preds",
            "total_loss",
        ]
        column_order = [
            col
            for col in wanted
            if col != "token_type_ids" or col in rounded.columns
        ]
        tokens_view = rounded[column_order]

        aggrid_interactive_table(tokens_view)
        tokens_csv = convert_df(tokens_view)
        st.download_button(
            label="Download csv",
            data=tokens_csv,
            file_name="processed_data.csv",
            mime="text/csv",
        )

        # --- Raw split, one row per exploded list element -------------------
        st.write("**Raw data (exploded by tokens)**")
        split_frame = context.split.to_pandas()
        exploded = split_frame.apply(pd.Series.explode)

        aggrid_interactive_table(exploded)
        exploded_csv = convert_df(exploded)
        st.download_button(
            label="Download csv",
            data=exploded_csv,
            file_name="raw_data.csv",
            mime="text/csv",
        )
|
|