model-usage / app.py
nazneen's picture
modular code and more features
f37cf2e
raw
history blame
3.19 kB
## LIBRARIES ###
from cProfile import label
from tkinter import font
from turtle import width
import streamlit as st
import pandas as pd
from datetime import datetime
import plotly.express as px
def select_plot_data(df, quantile_low, qunatile_high):
    """Reshape per-model usage data and keep the models inside a popularity band.

    Missing values are treated as zero usage. The table is transposed so that
    rows are time points (relabelled via ``date_range``) and columns are
    models. A model is kept when its mean usage lies between the
    ``quantile_low`` and ``qunatile_high`` quantiles of all models' means
    (bounds inclusive, so the range [0, 1] keeps every model).

    Args:
        df: Table with a ``'Model'`` column plus one column per timestamp;
            the timestamp column labels must be parseable by ``date_range``.
            The caller's frame is not modified.
        quantile_low: Lower quantile (0-1) of the per-model mean usage.
        qunatile_high: Upper quantile (0-1) of the per-model mean usage.
            (The misspelled name is kept for compatibility with keyword callers.)

    Returns:
        DataFrame indexed by date labels with one column per selected model.
    """
    # Non-inplace fillna: work on a copy instead of mutating the caller's frame.
    df_plot = df.fillna(0).set_index('Model').T
    df_plot.index = date_range(df_plot)

    # Each model's mean usage, computed once (same value describe() reports).
    model_means = df_plot.mean()
    quantile_lvalue = model_means.quantile(quantile_low)
    quantile_hvalue = model_means.quantile(qunatile_high)

    keep = (model_means >= quantile_lvalue) & (model_means <= quantile_hvalue)
    # Positional boolean mask: avoids label alignment issues with duplicate model names.
    return df_plot.loc[:, keep.to_numpy()]
def read_file_to_df(file):
    """Load the CSV at ``file`` (a path or file-like object) into a DataFrame."""
    frame = pd.read_csv(file)
    return frame
def date_range(df):
    """Convert ``df``'s timestamp index into short ``M/D/YY`` labels.

    Each index entry must be a string in the format
    ``'%Y-%m-%dT%H:%M:%S.%fZ'`` (e.g. ``'2022-03-07T00:00:00.000Z'``), which
    becomes ``'3/7/22'``: month and day without zero padding, followed by the
    last two characters of the year.

    Args:
        df: Object with an ``index`` holding the timestamp strings.

    Returns:
        list[str]: one label per index entry, in index order.

    Raises:
        ValueError: if an entry does not match the timestamp format.
    """
    fmt = '%Y-%m-%dT%H:%M:%S.%fZ'
    # Parse each timestamp once and reuse the resulting date for all three parts.
    dates = (datetime.strptime(stamp, fmt).date() for stamp in df.index.to_list())
    return [f"{d.month}/{d.day}/{str(d.year)[-2:]}" for d in dates]
if __name__ == "__main__":
    ### STREAMLIT APP CONFIG ###
    st.set_page_config(layout="wide", page_title="HF Hub Model Usage Visualization")
    st.header("Model Usage Visualization")
    with st.expander("How to read and interact with the plot:"):
        st.markdown("The plots below visualize weekly usage for HF models categorized by the model creation time.")
        st.markdown("Select the model creation time range you want to visualize using the dropdown menu below.")
        st.markdown("Choose the quantile range to filter out models with high or low usage.")
        st.markdown("The plots are interactive. Hover over the points to see the model name and the number of weekly mean usage. Click on the legend to hide/show the models.")

    # Widgets created with a `key` store their current value in st.session_state
    # automatically, so the returned values are already the session-state values;
    # no manual seeding of st.session_state is needed.
    model_init_year = st.multiselect("Model creation year", ["before_2021", "2021", "2022"], key="model_init_year", default="2022")
    popularity_low = st.slider("Model popularity quantile (lower limit) ", min_value=0.0, max_value=1.0, step=0.01, value=0.90, key="popularity_low")
    popularity_high = st.slider("Model popularity quantile (upper limit) ", min_value=0.0, max_value=1.0, step=0.01, value=0.99, key="popularity_high")

    if popularity_low > popularity_high:
        # An inverted range would select no models and render only empty charts.
        st.warning("The lower popularity quantile must not exceed the upper popularity quantile.")
    else:
        with st.container():
            for year in model_init_year:
                # One CSV per creation-year bucket, e.g. ./assets/2022/model_usage.csv
                df = read_file_to_df(f"./assets/{year}/model_usage.csv")
                df_plot_data = select_plot_data(df, popularity_low, popularity_high)
                fig = px.line(df_plot_data, title="Models created in " + year, labels={"index": "Weeks", "value": "Usage", "variable": "Model"})
                st.plotly_chart(fig, use_container_width=True)