|
"""Show count, mean and median loss per token and label.""" |
|
import streamlit as st |
|
|
|
from src.subpages.page import Context, Page |
|
from src.utils import AgGrid, aggrid_interactive_table |
|
|
|
|
|
@st.cache
def get_loss_by_token(df_tokens):
    """Summarize the loss distribution of every token.

    Groups ``df_tokens`` by its ``"tokens"`` column and computes the count,
    mean, median and sum of the ``"losses"`` column per token.

    Args:
        df_tokens: DataFrame with at least ``"tokens"`` and ``"losses"`` columns.

    Returns:
        A flat DataFrame with columns ``tokens, count, mean, median, sum``,
        ordered by total loss (``sum``), largest first.
    """
    grouped = df_tokens.groupby("tokens")[["losses"]]
    # Aggregating a one-column frame yields ("losses", <stat>) column pairs;
    # dropping the outer level leaves only the statistic names.
    summary = grouped.agg(["count", "mean", "median", "sum"]).droplevel(0, axis=1)
    ranked = summary.sort_values("sum", ascending=False)
    # Move the token keys out of the index into a regular column for display.
    return ranked.reset_index()
|
|
|
|
|
@st.cache
def get_loss_by_label(df_tokens):
    """Summarize the loss distribution of every label.

    Groups ``df_tokens`` by its ``"labels"`` column and computes the count,
    mean, median and sum of the ``"losses"`` column per label.

    Args:
        df_tokens: DataFrame with at least ``"labels"`` and ``"losses"`` columns.

    Returns:
        A flat DataFrame with columns ``labels, count, mean, median, sum``,
        ordered by average loss (``mean``), highest first.
    """
    grouped = df_tokens.groupby("labels")[["losses"]]
    # Aggregating a one-column frame yields ("losses", <stat>) column pairs;
    # dropping the outer level leaves only the statistic names.
    summary = grouped.agg(["count", "mean", "median", "sum"]).droplevel(0, axis=1)
    ranked = summary.sort_values("mean", ascending=False)
    # Move the label keys out of the index into a regular column for display.
    return ranked.reset_index()
|
|
|
|
|
class LossesPage(Page):
    """Page showing loss statistics aggregated per token and per label."""

    name = "Loss by Token/Label"
    icon = "sort-alpha-down"

    def render(self, context: Context):
        """Draw the per-token table (left) and the per-label table (right).

        Args:
            context: Page context; this method reads its ``df_tokens_merged``
                and ``df_tokens_cleaned`` frames.
        """
        st.title(self.name)

        with st.expander("💡", expanded=True):
            st.write("Show count, mean and median loss per token and label.")
            st.write(
                "Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues."
            )

        # Wide left column, narrow spacer, medium right column.
        left, _, right = st.columns([8, 1, 6])

        with left:
            st.subheader("💬 Loss by Token")

            merge = st.checkbox("Merge tokens", value=True, key="merge_tokens")
            # Mirror the widget value under a separate session key;
            # NOTE(review): presumably read by other pages — verify before removing.
            st.session_state["_merge_tokens"] = merge

            source_df = context.df_tokens_merged if merge else context.df_tokens_cleaned
            token_stats = get_loss_by_token(source_df)
            aggrid_interactive_table(token_stats.round(3))

            st.write(
                "_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._"
            )

        with right:
            st.subheader("🏷️ Loss by Label")
            label_stats = get_loss_by_label(context.df_tokens_cleaned)
            AgGrid(label_stats.round(3), height=200)
|
|