news_classification_model_monitor / classification_model_monitor.py
ksvmuralidhar's picture
Update classification_model_monitor.py
9d973fa verified
raw
history blame
11 kB
import streamlit as st
import pandas as pd
import numpy as np
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt
from read_predictions_from_db import PredictionDBRead
from read_daily_metrics_from_db import MetricsDBRead
from sklearn.metrics import balanced_accuracy_score, accuracy_score
import logging
from config import (CLASSIFIER_ADJUSTMENT_THRESHOLD,
PERFORMANCE_THRESHOLD,
CLASSIFIER_THRESHOLD)
logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)
def filter_prediction_data(data: pd.DataFrame):
    """Return only the prediction rows eligible for monitoring.

    Keeps rows whose ``y_true_proba`` equals 1, i.e. rows where the
    ground-truth label was recorded with full confidence.

    Parameters
    ----------
    data : pd.DataFrame
        Raw prediction data; must contain a ``y_true_proba`` column.

    Returns
    -------
    pd.DataFrame or None
        A filtered copy of ``data``, or None when ``data`` is None or
        any error occurs (callers treat None as "nothing to display").
    """
    try:
        logging.info("Entering filter_prediction_data()")
        if data is None:
            raise Exception("Input Prediction Data frame is None")
        # Only fully-confirmed ground-truth rows are used for monitoring.
        filtered_prediction_data = data.loc[data['y_true_proba'] == 1].copy()
        logging.info("Exiting filter_prediction_data()")
        return filtered_prediction_data
    except Exception as e:
        logging.critical(f"Error in filter_prediction_data(): {e}")
        return None
def get_adjusted_predictions(df):
    """Apply heuristic post-hoc adjustments to the model's predictions.

    Adjustments, in order:
    1. Predictions below ``CLASSIFIER_ADJUSTMENT_THRESHOLD`` confidence
       default to 'NATION'.
    2. 'NATION' articles whose text mentions 'Pakistan' are reassigned
       to 'WORLD'.
    3. Articles mentioning 'Zodiac Sign' or 'Horoscope' (case-insensitive)
       are reassigned to 'SCIENCE'.

    Parameters
    ----------
    df : pd.DataFrame or None
        Filtered prediction data with ``text``, ``y_pred`` and
        ``y_pred_proba`` columns.

    Returns
    -------
    pd.DataFrame or None
        A copy of ``df`` with adjusted ``y_pred``, or None on error.
    """
    try:
        logging.info("Entering get_adjusted_predictions()")
        if df is None:
            raise Exception('Input Filtered Prediction Data Frame is None')
        # Work on a copy so the caller's frame is never mutated.
        df = df.copy()
        df.reset_index(drop=True, inplace=True)
        df.loc[df['y_pred_proba'] < CLASSIFIER_ADJUSTMENT_THRESHOLD, 'y_pred'] = 'NATION'
        df.loc[(df['text'].str.contains('Pakistan')) & (df['y_pred'] == 'NATION'), 'y_pred'] = 'WORLD'
        df.loc[(df['text'].str.contains('Zodiac Sign', case=False)) | (df['text'].str.contains('Horoscope', case=False)), 'y_pred'] = 'SCIENCE'
        logging.info("Exiting get_adjusted_predictions()")
        return df
    except Exception as e:
        # critical, not info: matches the severity used by every sibling
        # function's error path in this module
        logging.critical(f"Error in get_adjusted_predictions(): {e}")
        return None
def display_kpis(data: pd.DataFrame, adj_data: pd.DataFrame):
    """Render accuracy KPI cards for raw and adjusted predictions.

    Shows balanced accuracy and plain accuracy for both the original
    and the heuristically adjusted predictions, plus the performance
    threshold and the sample count, in a 2-column Streamlit layout.

    Parameters
    ----------
    data : pd.DataFrame
        Filtered prediction data with ``y_true`` and ``y_pred`` columns.
    adj_data : pd.DataFrame
        Adjusted prediction data with the same columns.

    Side effects
    ------------
    Writes metric widgets and CSS to the Streamlit page; on error an
    error banner is shown instead.
    """
    try:
        logging.info("Entering display_kpis()")
        if data is None:
            raise Exception("Input Prediction Data frame is None")
        if adj_data is None:
            raise Exception('Input Adjusted Data frame is None')
        n_samples = len(data)
        balanced_accuracy = np.round(balanced_accuracy_score(data['y_true'], data['y_pred']), 4)
        accuracy = np.round(accuracy_score(data['y_true'], data['y_pred']), 4)
        adj_balanced_accuracy = np.round(balanced_accuracy_score(adj_data['y_true'], adj_data['y_pred']), 4)
        adj_accuracy = np.round(accuracy_score(adj_data['y_true'], adj_data['y_pred']), 4)
        # Force Streamlit columns to a fixed 1/3 width so the metric
        # cards line up regardless of viewport size.
        st.write('''<style>
[data-testid="column"] {
width: calc(33.3333% - 1rem) !important;
flex: 1 1 calc(33.3333% - 1rem) !important;
min-width: calc(33% - 1rem) !important;
}
</style>''',
                 unsafe_allow_html=True)
        col1, col2 = st.columns(2)
        with col1:
            st.metric(label="Balanced Accuracy", value=balanced_accuracy)
        with col2:
            st.metric(label="Adj Balanced Accuracy", value=adj_balanced_accuracy)
        col3, col4 = st.columns(2)
        with col3:
            st.metric(label="Accuracy", value=accuracy)
        with col4:
            st.metric(label="Adj Accuracy", value=adj_accuracy)
        col5, col6 = st.columns(2)
        with col5:
            st.metric(label="Bal Accuracy Threshold", value=PERFORMANCE_THRESHOLD)
        with col6:
            st.metric(label="N Samples", value=n_samples)
        logging.info("Exiting display_kpis()")
    except Exception as e:
        logging.critical(f'Error in display_kpis(): {e}')
        st.error("Couldn't display KPIs")
def plot_daily_metrics(metrics_df: pd.DataFrame):
    """Plot the daily balanced-accuracy time series with error bars.

    Draws a plotly line chart of ``mean_balanced_accuracy_score`` over
    ``evaluation_date`` with std-dev error bars and a horizontal
    threshold line at ``PERFORMANCE_THRESHOLD``.

    Parameters
    ----------
    metrics_df : pd.DataFrame
        Daily metrics with mean/std balanced-accuracy columns plus the
        evaluation-window metadata columns used in the hover tooltip.

    Side effects
    ------------
    Renders the chart to the Streamlit page; on error an error banner
    is shown instead.
    """
    try:
        logging.info("Entering plot_daily_metrics()")
        st.write("&nbsp;")
        if metrics_df is None:
            raise Exception('Input Metrics Data Frame is None')
        # Copy first: the derived columns and dtype conversion below must
        # not leak back into the caller's DataFrame.
        metrics_df = metrics_df.copy()
        metrics_df['evaluation_date'] = pd.to_datetime(metrics_df['evaluation_date'])
        metrics_df['mean_score_minus_std'] = np.round(metrics_df['mean_balanced_accuracy_score'] - metrics_df['std_balanced_accuracy_score'], 4)
        metrics_df['mean_score_plus_std'] = np.round(metrics_df['mean_balanced_accuracy_score'] + metrics_df['std_balanced_accuracy_score'], 4)
        # std is hidden from the tooltip: it is already conveyed via the
        # mean +/- std columns and the error bars.
        hover_data = {'mean_balanced_accuracy_score': True,
                      'std_balanced_accuracy_score': False,
                      'mean_score_minus_std': True,
                      'mean_score_plus_std': True,
                      'evaluation_window_days': True,
                      'n_splits': True,
                      'sample_start_date': True,
                      'sample_end_date': True,
                      'sample_size_of_each_split': True}
        hover_labels = {'mean_balanced_accuracy_score': "Mean Score",
                        'mean_score_minus_std': "Mean Score - Stdev",
                        'mean_score_plus_std': "Mean Score + Stdev",
                        'evaluation_window_days': "Observation Window (Days)",
                        'sample_start_date': "Observation Window Start Date",
                        'sample_end_date': "Observation Window End Date",
                        'n_splits': "N Splits For Evaluation",
                        'sample_size_of_each_split': "Sample Size of Each Split"}
        fig = px.line(data_frame=metrics_df, x='evaluation_date',
                      y='mean_balanced_accuracy_score',
                      error_y='std_balanced_accuracy_score',
                      title="Daily Balanced Accuracy",
                      color_discrete_sequence=['black'],
                      hover_data=hover_data, labels=hover_labels, markers=True)
        fig.add_hline(y=PERFORMANCE_THRESHOLD, line_dash="dash", line_color="green",
                      annotation_text="<b>THRESHOLD</b>",
                      annotation_position="left top")
        fig.update_layout(dragmode='pan')
        fig.update_layout(margin=dict(l=0, r=0, t=110, b=10))
        st.plotly_chart(fig, use_container_width=True)
        logging.info("Exiting plot_daily_metrics()")
    except Exception as e:
        logging.critical(f'Error in plot_daily_metrics(): {e}')
        st.error("Couldn't Plot Daily Model Metrics")
def get_misclassified_classes(data):
    """Compute per-class misclassification rates and the failing rows.

    Parameters
    ----------
    data : pd.DataFrame or None
        Prediction data with ``text``, ``y_true``, ``y_pred``,
        ``y_pred_proba`` and ``url`` columns.

    Returns
    -------
    tuple
        ``(misclassification_rates, misclassified_examples)`` where the
        rates are a Series (misclassified / predicted, per class,
        rounded to 2 decimals, sorted descending) and the examples are
        the misclassified rows sorted by class then confidence.
        ``(None, None)`` on error.
    """
    try:
        logging.info("Entering get_misclassified_classes()")
        if data is None:
            raise Exception("Input Prediction Data Frame is None")
        data = data.copy()
        data['match'] = (data['y_true'] == data['y_pred']).astype('int')
        y_pred_counts = data['y_pred'].value_counts()
        misclassified_examples = data.loc[data['match'] == 0, ['text', 'y_true', 'y_pred', 'y_pred_proba', 'url']].copy()
        misclassified_examples.sort_values(by=['y_pred', 'y_pred_proba'], ascending=[True, False], inplace=True)
        # reindex (rather than label indexing) so a class with zero
        # misclassifications yields a 0 rate instead of raising KeyError
        misclassifications = (data.loc[data['match'] == 0, 'y_pred']
                              .value_counts()
                              .reindex(y_pred_counts.index, fill_value=0))
        misclassifications /= y_pred_counts
        misclassifications.sort_values(ascending=False, inplace=True)
        logging.info("Exiting get_misclassified_classes()")
        return np.round(misclassifications, 2), misclassified_examples
    except Exception as e:
        logging.critical(f'Error in get_misclassified_classes(): {e}')
        return None, None
def display_misclassified_examples(misclassified_classes, misclassified_examples):
    """Render the misclassification bar chart and the failing examples.

    Parameters
    ----------
    misclassified_classes : pd.Series or None
        Per-class misclassification rates (from get_misclassified_classes).
    misclassified_examples : pd.DataFrame or None
        The misclassified rows to show in a table.

    Side effects
    ------------
    Draws a matplotlib bar chart and a dataframe to the Streamlit page;
    on error an error banner is shown instead.
    """
    try:
        logging.info("Entering display_misclassified_examples()")
        st.write("&nbsp;")
        if misclassified_classes is None:
            raise Exception('Misclassified Classes Distribution Data Frame is None')
        if misclassified_examples is None:
            raise Exception('Misclassified Examples Data Frame is None')
        fig, ax = plt.subplots(figsize=(10, 4.5))
        misclassified_classes.plot(kind='bar', ax=ax, color='black', title="Misclassification percentage")
        plt.yticks([])
        plt.xlabel("")
        ax.bar_label(ax.containers[0])
        st.pyplot(fig)
        # Close the figure once rendered: Streamlit reruns this script on
        # every interaction, and unclosed figures accumulate in pyplot's
        # global state (memory leak).
        plt.close(fig)
        st.markdown("<b>Misclassified examples</b>", unsafe_allow_html=True)
        st.dataframe(misclassified_examples, hide_index=True)
        # Hide the dataframe toolbar (download/search icons).
        st.markdown(
            """
            <style>
            [data-testid="stElementToolbar"] {
            display: none;
            }
            </style>
            """,
            unsafe_allow_html=True
        )
        logging.info("Exiting display_misclassified_examples()")
    except Exception as e:
        logging.critical(f'Error in display_misclassified_examples(): {e}')
        st.error("Couldn't display Misclassification Data")
def classification_model_monitor():
    """Render the full classification-model monitoring dashboard.

    Pipeline: read predictions from the DB, filter and adjust them,
    show accuracy KPIs, plot the daily metrics, then show the
    misclassification breakdown. Any unexpected failure is logged and
    surfaced as a single Streamlit error banner.
    """
    try:
        st.write("<h4>Classification Model Monitor</h4>", unsafe_allow_html=True)
        predictions_source = PredictionDBRead()
        metrics_source = MetricsDBRead()
        # Fetch and prepare prediction data
        raw_predictions = predictions_source.read_predictions_from_db()
        eligible_predictions = filter_prediction_data(raw_predictions)
        adjusted_predictions = get_adjusted_predictions(eligible_predictions)
        # KPI cards (raw vs adjusted accuracy)
        display_kpis(eligible_predictions, adjusted_predictions)
        # Daily balanced-accuracy trend
        plot_daily_metrics(metrics_source.read_metrics_from_db())
        # Misclassification rates and the offending rows
        class_error_rates, error_examples = get_misclassified_classes(eligible_predictions)
        display_misclassified_examples(class_error_rates, error_examples)
        # Shrink the KPI metric font so six cards fit comfortably
        st.markdown(
            """<style>
            [data-testid="stMetricValue"] {
            font-size: 25px;
            }
            </style>
            """, unsafe_allow_html=True
        )
    except Exception as e:
        logging.critical(f"Error in classification_model_monitor(): {e}")
        st.error("Unexpected Error. Couldn't display Classification Model Monitor")