Spaces:
Running
Running
import os | |
import datasets | |
import numpy as np | |
import pandas as pd | |
import pymysql.cursors | |
import streamlit as st | |
from streamlit_elements import elements, mui, html, dashboard, nivo | |
from streamlit_extras.switch_page_button import switch_page | |
from streamlit_extras.metric_cards import style_metric_cards | |
from streamlit_extras.stylable_container import stylable_container | |
from pages.Gallery import load_hf_dataset | |
from pages.Ranking import connect_to_db | |
class DashboardApp:
    """Streamlit page showing which models the current user ranked highest.

    The page reads the user's ranking records through the module-level
    ``RANKING_CONN`` database connection, turns them into a per-model score
    and renders a "top picks" view plus a detailed table.

    Args:
        roster: DataFrame with (at least) the columns ``modelVersion_id``,
            ``model_id``, ``model_name``, ``modelVersion_name`` and
            ``modelVersion_url``.
        promptBook: DataFrame with (at least) the columns ``modelVersion_id``
            and ``image_id``.
        session_finished: Keys of the ranking sessions the user has finished.
            Stored on the instance; not read by the methods of this class.
    """

    # Table names are interpolated into SQL text (identifiers cannot be bound
    # as query parameters), so only these known tables are accepted.
    _ALLOWED_TABLES = ('sort_results', 'battle_results')

    # Points awarded per position in one record of the ``sort_results`` table.
    _SORT_POINTS = {'position1': 5, 'position2': 3, 'position3': 1, 'position4': 0}

    # First / second / third place medal emoji, written as escapes so the
    # characters survive a re-encoding of this file.
    _MEDALS = ('\U0001F947', '\U0001F948', '\U0001F949')

    # Number of models highlighted in the "Top Picks" tab.
    _TOP_N = 3

    _CARD_CSS = """
        {
            border: 1.5px solid rgba(49, 51, 63, 0.2);
            border-left: 0.5rem solid gold;
            border-radius: 5px;
            padding: calc(1em + 5px);
            gap: 0.5em;
            box-shadow: 0 0 2rem rgba(0, 0, 0, 0.08);
            overflow-x: scroll;
        }
        """

    def __init__(self, roster, promptBook, session_finished):
        self.roster = roster
        self.promptBook = promptBook
        self.session_finished = session_finished

    def sidebar(self, tags, mode):
        """Render the tag selector in the sidebar and return the chosen tag.

        Args:
            tags: Options offered in the select box.
            mode: Current ranking mode; accepted for interface compatibility
                and not used here.
        """
        with st.sidebar:
            tag = st.selectbox('Select a tag', tags, key='tag')
        return tag

    def _checked_table(self, db_table):
        """Return ``db_table`` if it is a known ranking table, else raise ValueError."""
        if db_table not in self._ALLOWED_TABLES:
            raise ValueError(f'Unknown ranking table: {db_table!r}')
        return db_table

    def _user_params(self):
        """Return the (username, timestamp) pair identifying the current user."""
        user_id = st.session_state.user_id
        # The values were previously embedded as quoted string literals; binding
        # them as strings keeps the same comparison semantics in the database.
        return str(user_id[0]), str(user_id[1])

    def _run_query(self, sql, params):
        """Execute a parameterized query on ``RANKING_CONN`` and return all rows.

        ``sql`` uses ``%s`` placeholders (the pymysql parameter style). The
        cursor is closed even when the query raises.
        """
        cursor = RANKING_CONN.cursor()
        try:
            cursor.execute(sql, params)
            return cursor.fetchall()
        finally:
            cursor.close()

    def leaderboard(self, tag, db_table):
        """Render the leaderboard of the current user for one tag.

        Args:
            tag: Tag to filter on; ``'all'`` matches every tag.
            db_table: ``'sort_results'`` or ``'battle_results'``.

        Raises:
            ValueError: If ``db_table`` is not one of the known tables.
        """
        table = self._checked_table(db_table)
        # '%' is the SQL LIKE wildcard, i.e. "any tag".
        tag_pattern = '%' if tag == 'all' else tag

        # get the ranking results of the current user
        username, timestamp = self._user_params()
        results = self._run_query(
            f"SELECT * FROM {table} WHERE username = %s AND timestamp = %s AND tag LIKE %s",
            (username, timestamp, tag_pattern),
        )

        scores = self.score_calculator(results, db_table)
        # list of (modelVersion_id, score) tuples, best score first
        standings = sorted(scores.items(), key=lambda item: item[1], reverse=True)

        if not standings:
            # Nothing to rank; indexing into the standings below would fail.
            st.info('No ranking results found for the selected tag.')
            return

        tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])

        with tab1:
            self._render_top_picks(standings)

        st.chat_input('Please leave your comments here.', key='comment')

        with tab2:
            st.write('## Detailed information of all selected models')
            standings_df = pd.DataFrame(standings, columns=['modelVersion_id', 'ranking_score'])
            detailed_info = pd.merge(standings_df, self.roster, on='modelVersion_id')
            st.data_editor(detailed_info, hide_index=True, disabled=True)

    def _render_top_picks(self, standings):
        """Show up to ``_TOP_N`` best models as metric cards, one per column."""
        st.write('## Top picks')
        metric_cols = st.columns(self._TOP_N)
        # Placeholder below the cards; filled when a "Show Image" button is pressed.
        image_display = st.empty()

        # Fewer than _TOP_N models may have been ranked (a single battle only
        # involves two), so iterate over what exists instead of a fixed range.
        for rank, (modelVersion_id, ranking_score) in enumerate(standings[:self._TOP_N]):
            with metric_cols[rank]:
                self._render_pick_card(rank, modelVersion_id, ranking_score, image_display)

    def _render_pick_card(self, rank, modelVersion_id, ranking_score, image_display):
        """Render one metric card and, on request, the model's images."""
        roster_rows = self.roster[self.roster['modelVersion_id'] == modelVersion_id]
        model_id, model_name, modelVersion_name, _url = roster_rows[
            ['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']
        ].values[0]
        icon = self._MEDALS[rank]

        metric_card = stylable_container(
            key="container_with_border",
            css_styles=self._CARD_CSS,
        )
        with metric_card:
            st.write(f'### {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})')
            st.write(f'Ranking Score: {ranking_score}')
            show_image = st.button('Show Image', key=modelVersion_id, use_container_width=True)
            if show_image:
                self._render_images(image_display, modelVersion_id, icon, model_name, modelVersion_name)

    def _render_images(self, image_display, modelVersion_id, icon, model_name, modelVersion_name):
        """Fill ``image_display`` with every prompt-book image of one model version."""
        image_ids = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
        with image_display.container():
            st.write('---')
            st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
            col_num = 4
            image_cols = st.columns(col_num)
            # Distribute the images round-robin over the columns.
            for position, image_id in enumerate(image_ids):
                with image_cols[position % col_num]:
                    image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{image_id}.png"
                    st.image(image, use_column_width=True)

    def score_calculator(self, results, db_table):
        """Turn ranking records into a ``{modelVersion_id: score}`` mapping.

        Args:
            results: Iterable of dict-like records, each with a ``battletime``
                key. Battle records also carry ``winner`` and ``loser``; sort
                records carry ``position1`` .. ``position4``.
            db_table: ``'battle_results'`` or ``'sort_results'``; selects the
                scoring rule. Any other value yields an empty mapping.

        Returns:
            dict mapping each model version that appears in ``results`` to its
            accumulated score.
        """
        # Scores are order dependent for battles, so process chronologically.
        ordered = sorted(results, key=lambda record: record['battletime'])

        standings = {}
        if db_table == 'battle_results':
            for record in ordered:
                winner, loser = record['winner'], record['loser']
                standings[winner] = standings.get(winner, 0) + 1
                # make sure a loser who never wins still shows up
                standings.setdefault(loser, 0)
                # the winner also inherits the loser's score so far
                standings[winner] += standings[loser]
        elif db_table == 'sort_results':
            for record in ordered:
                for position, points in self._SORT_POINTS.items():
                    model = record[position]
                    standings[model] = standings.get(model, 0) + points
        return standings

    def app(self):
        """Render the page: mode switch, tag selector and leaderboard."""
        st.title('Your Preferred Models', help="Scores are calculated based on your ranking results.")

        mode = st.sidebar.radio('Ranking mode', ['Sort', 'Battle'], horizontal=True, index=1)

        # get tags from database of the current user
        db_table = 'sort_results' if mode == 'Sort' else 'battle_results'
        table = self._checked_table(db_table)
        rows = self._run_query(
            f"SELECT DISTINCT tag FROM {table} WHERE username = %s AND timestamp = %s",
            self._user_params(),
        )
        tags = ['all'] + [row['tag'] for row in rows]

        if len(tags) == 1:
            # only the synthetic 'all' entry: the user has no records in this mode
            st.info(f'No rankings are finished with {mode} mode yet.')
        else:
            tag = self.sidebar(tags, mode)
            self.leaderboard(tag, db_table)
if __name__ == "__main__":
    # Page entry point: route the visitor to the right place depending on how
    # far they are (logged in -> images checked -> a ranking session finished).
    st.set_page_config(layout="wide")

    # Button labels with their emoji written as escapes so they survive a
    # re-encoding of this file. U+1F5BC is FRAMED PICTURE.
    gallery_label = '\U0001F5BC\uFE0F Go to Gallery'
    # NOTE(review): the ranking icon was unreadable in the source; U+1F396
    # MILITARY MEDAL is assumed here -- confirm against the Ranking page.
    ranking_label = '\U0001F396\uFE0F Go to Ranking'

    if 'user_id' not in st.session_state:
        st.warning('Please log in first.')
        if st.button('Go to Home Page'):
            switch_page("home")

    elif 'progress' not in st.session_state:
        st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
        if st.button(gallery_label):
            switch_page('gallery')

    else:
        # keys of the progress entries marked as finished
        session_finished = [
            key for key, value in st.session_state.progress.items() if value == 'finished'
        ]

        if not session_finished:
            st.info('A dashboard showing your preferred models will appear after you finish any ranking session.')
            if st.button(ranking_label):
                switch_page('ranking')
            if st.button(gallery_label):
                switch_page('gallery')
        else:
            roster, promptBook, images_ds = load_hf_dataset()
            # Module-level name: DashboardApp reads RANKING_CONN as a global.
            RANKING_CONN = connect_to_db()
            app = DashboardApp(roster, promptBook, session_finished)
            app.app()