Spaces:
Runtime error
Runtime error
import pandas as pd | |
import streamlit as st | |
from streamlit_option_menu import option_menu | |
from load import load_context | |
from subpages import ( | |
DebugPage, | |
FindDuplicatesPage, | |
HomePage, | |
LossesPage, | |
LossySamplesPage, | |
MetricsPage, | |
MisclassifiedPage, | |
Page, | |
ProbingPage, | |
RandomSamplesPage, | |
RawDataPage, | |
) | |
from subpages.attention import AttentionPage | |
from subpages.hidden_states import HiddenStatesPage | |
from subpages.inspect import InspectPage | |
from utils import classmap | |
# Page-level configuration: wide layout, browser-tab title and tab icon.
st.set_page_config(
    page_title="Error Analysis",
    page_icon="🏷️",
    layout="wide",
)
# Short module-level alias for the sidebar container.
sts = st.sidebar
def _show_menu(pages: list[Page]) -> int:
    """Render the sidebar navigation menu and return the selected page index.

    The selected page *name* is also stored in ``st.session_state.active_page``.

    Args:
        pages: Pages to list in the menu, in display order; each page's
            ``name`` and ``icon`` attributes are shown.

    Returns:
        Index into ``pages`` of the page currently selected in the menu.
    """
    page_names = [p.name for p in pages]
    page_icons = [p.icon for p in pages]
    with st.sidebar:
        selected_menu_item = st.session_state.active_page = option_menu(
            menu_title="ExplaiNER",
            options=page_names,
            icons=page_icons,
            menu_icon="layout-wtf",
            default_index=0,
        )
    # Return after the `with` block rather than inside it: the return path is
    # then explicit, and no unreachable `assert False` is needed to satisfy
    # type checkers (an assert would also vanish under `python -O`).
    return page_names.index(selected_menu_item)
def _initialize_session_state(pages: list[Page]) -> None:
    """Seed Streamlit session state with each page's widget defaults.

    Defaults are only written while the ``active_page`` key is absent, i.e.
    before ``_show_menu`` has stored a selection for this session.

    Args:
        pages: Pages whose ``get_widget_defaults()`` mappings are merged into
            ``st.session_state``.
    """
    if "active_page" not in st.session_state:
        for page in pages:
            st.session_state.update(**page.get_widget_defaults())
    # Re-assigns every session-state key to its current value.
    # NOTE(review): presumably a workaround so that widget-backed keys are not
    # discarded when their widget is not rendered on the current page, and
    # assumed to run on every rerun (outside the `if`); the original nesting
    # could not be recovered from the source — confirm against upstream.
    st.session_state.update(st.session_state)
def _write_color_legend(context):
    """Write a colour legend for the entity labels of ``context`` to the sidebar.

    Prefixes before the first "-" are stripped (e.g. "B-PER"/"I-PER" -> "PER"),
    so each entity type appears once. Each row is shaded with the colour stored
    in ``st.session_state["color_<label>"]`` (black if unset) and shows the
    label's display name from ``classmap``.

    Args:
        context: Loaded context; only its ``labels`` attribute is read
            (presumably an iterable of label strings — confirm).
    """
    # Sorted so the legend order is stable across reruns and server restarts;
    # iterating a set of str depends on hash randomization.
    labels = sorted({lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels})
    colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]

    def display_name(lbl):
        # Fall back to the raw label instead of crashing the whole page when
        # classmap has no entry for it.
        try:
            return classmap[lbl]
        except KeyError:
            return lbl

    def style(column):
        # One CSS declaration per row; row order matches `labels` / `colors`.
        return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]

    # One row per label, single "label" column (the original built this and
    # then transposed it twice, which is a no-op).
    color_legend_df = pd.DataFrame(
        [display_name(lbl) for lbl in labels], columns=["label"], index=labels
    )
    st.sidebar.write(
        color_legend_df.style.apply(style, axis=0).set_properties(
            **{"color": "white", "text-align": "center"}
        )
    )
def main():
    """Build the page list, render the sidebar menu and the selected page.

    The home page is rendered without a context. Every other page requires
    that setup has populated ``model_name`` in session state. For those pages
    the context is loaded from session state, the colour legend is written,
    and the page is then rendered with that context.
    """
    pages: list[Page] = [
        HomePage(),
        AttentionPage(),
        HiddenStatesPage(),
        ProbingPage(),
        MetricsPage(),
        LossySamplesPage(),
        LossesPage(),
        MisclassifiedPage(),
        RandomSamplesPage(),
        FindDuplicatesPage(),
        InspectPage(),
        RawDataPage(),
        DebugPage(),
    ]
    _initialize_session_state(pages)
    selected_page_idx = _show_menu(pages)
    selected_page = pages[selected_page_idx]

    if isinstance(selected_page, HomePage):
        selected_page.render()
        return

    if "model_name" not in st.session_state:
        # this can happen if someone loads another page directly (without going through home)
        st.error("Setup not complete. Please click on 'Home / Setup' in the left menu bar.")
        return

    context = load_context(**st.session_state)
    _write_color_legend(context)
    selected_page.render(context)
# Script entry point: only run the app when this file is executed directly.
if __name__ == "__main__":
    main()