|
"""The App module is the main entry point for the application. |
|
|
|
Run `streamlit run app.py` to start the app. |
|
""" |
|
|
|
import pandas as pd |
|
import streamlit as st |
|
from streamlit_option_menu import option_menu |
|
|
|
from src.load import load_context |
|
from src.subpages import ( |
|
DebugPage, |
|
FindDuplicatesPage, |
|
HomePage, |
|
LossesPage, |
|
LossySamplesPage, |
|
MetricsPage, |
|
MisclassifiedPage, |
|
Page, |
|
ProbingPage, |
|
RandomSamplesPage, |
|
RawDataPage, |
|
) |
|
from src.subpages.attention import AttentionPage |
|
from src.subpages.hidden_states import HiddenStatesPage |
|
from src.subpages.inspect import InspectPage |
|
from src.utils import classmap |
|
|
|
# Short alias for the sidebar container (NOTE(review): not referenced in the
# visible part of this file; kept in case other code relies on it).
sts = st.sidebar

# Page-wide configuration: browser-tab title and icon, full-width layout.
st.set_page_config(page_title="Error Analysis", page_icon="🏷️", layout="wide")
|
|
|
|
|
def _show_menu(pages: list[Page]) -> int:
    """Draw the navigation menu in the sidebar and return the selected page index.

    The selected page name is also stored in ``st.session_state.active_page``.

    Args:
        pages: Pages to list, in menu order. Each contributes its ``name``
            (menu label) and ``icon``.

    Returns:
        Index into ``pages`` of the menu item currently selected by the user.
    """
    page_names = [p.name for p in pages]
    page_icons = [p.icon for p in pages]

    with st.sidebar:
        selected_menu_item = option_menu(
            menu_title="ExplaiNER",
            options=page_names,
            icons=page_icons,
            menu_icon="layout-wtf",
            default_index=0,
        )

    # Record the selection. _initialize_session_state treats the presence of
    # this key as "initial defaults have already been seeded".
    st.session_state.active_page = selected_menu_item
    # Returning after the `with` block removes the need for an unreachable
    # `assert False` sentinel at the end of the function.
    return page_names.index(selected_menu_item)
|
|
|
|
|
def _initialize_session_state(pages: list[Page]):
    """Seed widget defaults on the first run, then re-store all session state.

    Args:
        pages: Pages whose ``_get_widget_defaults()`` mappings are written into
            ``st.session_state`` on the first run. The first run is detected by
            the absence of the ``active_page`` key.
    """
    is_first_run = "active_page" not in st.session_state
    if is_first_run:
        for pg in pages:
            widget_defaults = pg._get_widget_defaults()
            st.session_state.update(**widget_defaults)

    # Write every session-state entry back to itself on each run.
    # NOTE(review): presumably the common Streamlit workaround that keeps
    # widget-keyed values from being discarded when their widgets are not
    # rendered on the current page. Confirm before removing.
    st.session_state.update(st.session_state)
|
|
|
|
|
def _write_color_legend(context):
    """Render a color legend for the entity classes in the sidebar.

    Prefixes before the first ``-`` (e.g. ``B-``/``I-``) are dropped, so each
    entity type appears once. Each row shows ``classmap[label]`` on the
    background color stored in ``st.session_state["color_<label>"]``. The
    color falls back to ``#000000`` when no color has been set.

    Args:
        context: Loaded context; only ``context.labels`` is read here.
    """
    # Keep the segment after the first "-" (matching the key scheme used for
    # the "color_<label>" session-state entries) and deduplicate while
    # preserving first-seen order. A plain `set` would make the legend order
    # depend on per-process string-hash randomization.
    labels = list(
        dict.fromkeys(
            lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels
        )
    )
    colors = {lbl: st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels}

    legend_df = pd.DataFrame(
        [classmap[lbl] for lbl in labels], columns=["label"], index=labels
    )

    def style(column: pd.Series) -> list[str]:
        # Look each row's color up by its own label (the index), so styling
        # cannot drift out of sync with row order.
        return [f"background-color: {colors[lbl]}; opacity: 1;" for lbl in column.index]

    st.sidebar.write(
        legend_df.style.apply(style, axis=0).set_properties(
            **{"color": "white", "text-align": "center"}
        )
    )
|
|
|
|
|
def main():
    """The main entry point for the application.

    Builds the page list, initializes session state, draws the sidebar menu
    and renders the selected page.

    The home/setup page is rendered without a context. Every other page
    requires ``model_name`` to be present in session state. It then receives
    a context loaded from the session state, and a color legend is drawn in
    the sidebar.
    """
    pages: list[Page] = [
        HomePage(),
        AttentionPage(),
        HiddenStatesPage(),
        ProbingPage(),
        MetricsPage(),
        LossySamplesPage(),
        LossesPage(),
        MisclassifiedPage(),
        RandomSamplesPage(),
        FindDuplicatesPage(),
        InspectPage(),
        RawDataPage(),
        DebugPage(),
    ]

    _initialize_session_state(pages)

    selected_page = pages[_show_menu(pages)]

    # The home page is where setup happens, so it must render before any
    # model/context has been configured.
    if isinstance(selected_page, HomePage):
        selected_page.render()
        return

    if "model_name" not in st.session_state:
        # Every other page needs a loaded context; direct the user to setup.
        st.error("Setup not complete. Please click on 'Home / Setup' in the left menu bar.")
        return

    # NOTE(review): every session-state entry is forwarded as a keyword
    # argument; load_context presumably accepts/ignores extra keys — verify.
    context = load_context(**st.session_state)
    _write_color_legend(context)
    selected_page.render(context)
|
|
|
|
|
# Run the app when this file is executed as the main script (the module
# docstring says to launch it with `streamlit run app.py`).
if __name__ == "__main__":
    main()
|
|