"""Streamlit page: test metrics from ``assets/nbow*`` files plus best/worst query tables."""
import json
from pathlib import Path as P

import pandas as pd
import streamlit as st

import config

# Materialised and sorted so the files are shown in a stable order and the
# name can be iterated more than once (a bare glob() generator is one-shot
# and yields entries in OS-dependent order).
nbow_results_path = sorted(P("assets").glob("nbow*"))


def display_metrics_dict(metrics, display_only_accuracy):
    """Render one metrics dictionary.

    Args:
        metrics: mapping that must contain ``"model_name"`` and ``"columns"``
            (an underscore-separated string); when ``display_only_accuracy``
            is true it must also contain ``metrics["accuracy@k"]["10"]``.
        display_only_accuracy: if true, show only the accuracy@10 value;
            otherwise show every remaining entry of ``metrics`` as JSON.

    Raises:
        KeyError: if one of the required keys is missing.
    """
    # Work on a shallow copy so the caller's dict is not stripped of the
    # "model_name" / "columns" keys as a side effect.
    metrics = dict(metrics)
    model_name = metrics.pop("model_name")
    columns = metrics.pop("columns").split("_")
    st.markdown(f"### columns: {columns}")
    st.markdown(f"best model {model_name}")
    if not display_only_accuracy:
        st.json(metrics)
    else:
        st.json({"accuracy@10": metrics["accuracy@k"]["10"]})


def display_metrics():
    """Show the metrics stored in every file listed in ``nbow_results_path``.

    A sidebar checkbox (ticked by default) switches between showing only
    accuracy@10 and showing the full metrics JSON.
    """
    display_only_accuracy = st.sidebar.checkbox("display only accuracy@10", value=True)
    st.markdown("## Test metrics for best validation modelon given columns")
    for p in nbow_results_path:
        # Context manager closes the handle promptly instead of leaking it.
        with open(p, "r") as metrics_file:
            metrics = json.load(metrics_file)
        display_metrics_dict(metrics, display_only_accuracy)


display_metrics()

best_results_df = pd.read_csv(config.best_tasks_path)
worst_results_df = pd.read_csv(config.worst_tasks_path)

show_worst_best_statistics = st.sidebar.checkbox(
    label="show worst/best statistics grouped by area"
)
show_area_aggregated_results = st.sidebar.checkbox(
    label="show results aggregated by area"
)

if show_worst_best_statistics:
    # Heading and paragraph are on separate lines so that only the first
    # line renders as a heading.
    st.markdown(
        """
        ## Worst/best queries

        The following are top 10 worst/best queries per area by number of hits.
        There are at least 10 documents per query in the test set, so number of hits/10 is the accuracy.
        """
    )
    # Explicit key: the other "sort by" selectbox below shares this label, and
    # two widgets with identical parameters would clash when both sections
    # are enabled.
    sort_key = st.selectbox(
        "sort by", list(best_results_df.columns), key="sort_by_queries"
    )
    st.markdown("## Queries with best results")
    st.table(best_results_df.sort_values(sort_key, ascending=False))
    st.markdown("## Queries with worst results")
    st.table(worst_results_df.sort_values(sort_key, ascending=False))

if show_area_aggregated_results:
    st.markdown("## Area aggregated results")
    # numeric_only=True averages only numeric columns; taking the mean of a
    # non-numeric column raises TypeError on pandas >= 2.0.
    best_results_agg = (
        best_results_df.groupby("area").mean(numeric_only=True).reset_index()
    )
    worst_results_agg = (
        worst_results_df.groupby("area").mean(numeric_only=True).reset_index()
    )
    sort_key = st.selectbox(
        "sort by", list(best_results_agg.columns), key="sort_by_area"
    )
    st.markdown("Best results")
    st.table(best_results_agg.sort_values(sort_key, ascending=False))
    st.markdown("Worst results")
    st.table(worst_results_agg.sort_values(sort_key, ascending=False))