GEMRec-Gallery / pages /Summary.py
Ricercar's picture
descriptions update
2d413f9
raw
history blame
8.24 kB
import os
import datasets
import numpy as np
import pandas as pd
import pymysql.cursors
import streamlit as st
from streamlit_elements import elements, mui, html, dashboard, nivo
from streamlit_extras.switch_page_button import switch_page
from streamlit_extras.metric_cards import style_metric_cards
from streamlit_extras.stylable_container import stylable_container
from pages.Gallery import load_hf_dataset
from pages.Ranking import connect_to_db
class DashboardApp:
    """Streamlit dashboard summarising the current user's preferred models.

    Scores are computed from the logged-in user's rows in a ranking table
    (``sort_results`` or ``battle_results``), queried through the module-level
    ``RANKING_CONN`` connection that the page's ``__main__`` section creates.
    """

    # Tables this page may query. A table name cannot be bound as a query
    # parameter, so it is checked against this whitelist before it is
    # interpolated into SQL.
    _RANKING_TABLES = ('sort_results', 'battle_results')

    # Medals shown on the 1st, 2nd and 3rd top-pick cards.
    _MEDALS = ('🥇', '🥈', '🥉')

    # Number of image columns in the "Show Image" gallery.
    _IMAGE_COLUMNS = 4

    def __init__(self, roster, promptBook, session_finished):
        """Store the data the dashboard renders.

        Args:
            roster: DataFrame of model versions. The columns read here are
                ``modelVersion_id``, ``model_id``, ``model_name``,
                ``modelVersion_name`` and ``modelVersion_url``.
            promptBook: DataFrame with at least ``modelVersion_id`` and
                ``image_id`` columns.
            session_finished: list of finished ranking sessions; stored on the
                instance but not read by this class.
        """
        self.roster = roster
        self.promptBook = promptBook
        self.session_finished = session_finished

    def sidebar(self, tags, mode):
        """Render the tag selector in the sidebar and return the chosen tag.

        ``mode`` is accepted for interface compatibility but is not used.
        """
        with st.sidebar:
            tag = st.selectbox('Select a tag', tags, key='tag')
        return tag

    @classmethod
    def _checked_table(cls, db_table):
        """Return ``db_table`` if it is a known ranking table, else raise ValueError."""
        if db_table not in cls._RANKING_TABLES:
            raise ValueError(f'Unknown ranking table: {db_table!r}')
        return db_table

    @staticmethod
    def _fetch_user_rows(sql, *extra_params):
        """Run ``sql`` for the logged-in user and return all fetched rows.

        ``sql`` must use ``%s`` placeholders. The first two are bound to
        ``st.session_state.user_id[0]`` and ``[1]`` (matched against the
        ``username`` and ``timestamp`` columns), followed by ``extra_params``.
        Values are passed to the driver as parameters and never formatted into
        the SQL text.
        """
        params = (st.session_state.user_id[0], st.session_state.user_id[1], *extra_params)
        # The context manager closes the cursor even if execute() raises.
        with RANKING_CONN.cursor() as cursor:
            cursor.execute(sql, params)
            return cursor.fetchall()

    def leaderboard(self, tag, db_table):
        """Render the user's ranked models for ``tag`` from ``db_table``.

        ``tag == 'all'`` matches every tag. The "Top Picks" tab shows up to three
        medal cards plus a comment box. The "Detailed Info" tab lists every
        ranked model joined with the roster.

        Raises:
            ValueError: if ``db_table`` is not one of the known ranking tables.
        """
        tag_pattern = '%' if tag == 'all' else tag
        table = self._checked_table(db_table)
        results = self._fetch_user_rows(
            f"SELECT * FROM {table} WHERE username = %s AND timestamp = %s AND tag LIKE %s",
            tag_pattern,
        )
        standings = self.score_calculator(results, db_table)
        # (modelVersion_id, score) pairs, highest score first.
        ranked = sorted(standings.items(), key=lambda item: item[1], reverse=True)

        tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
        with tab1:
            self._render_top_picks(ranked)
            st.chat_input('Please leave your comments here.', key='comment')
        with tab2:
            st.write('## Detailed information of all selected models')
            detailed_info = pd.merge(
                pd.DataFrame(ranked, columns=['modelVersion_id', 'ranking_score']),
                self.roster,
                on='modelVersion_id',
            )
            st.data_editor(detailed_info, hide_index=True, disabled=True)

    def _render_top_picks(self, ranked):
        """Show medal cards for the (up to) three best entries of ``ranked``."""
        st.write('## Top picks')
        if not ranked:
            st.info('No ranking results found for this tag.')
            return
        metric_cols = st.columns(len(self._MEDALS))
        # Shared slot below the cards. Whichever "Show Image" button was pressed fills it.
        image_display = st.empty()
        # zip() stops at the shorter sequence. Fewer than three ranked models
        # (e.g. a single battle yields only two) therefore leaves columns empty
        # instead of indexing past the end of ``ranked``.
        for col, icon, (modelVersion_id, score) in zip(metric_cols, self._MEDALS, ranked):
            with col:
                model_id, model_name, modelVersion_name, url = self.roster[
                    self.roster['modelVersion_id'] == modelVersion_id
                ][['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
                metric_card = stylable_container(
                    key="container_with_border",
                    css_styles="""
                        {
                            border: 1.5px solid rgba(49, 51, 63, 0.2);
                            border-left: 0.5rem solid gold;
                            border-radius: 5px;
                            padding: calc(1em + 5px);
                            gap: 0.5em;
                            box-shadow: 0 0 2rem rgba(0, 0, 0, 0.08);
                            overflow-x: scroll;
                        }
                        """,
                )
                with metric_card:
                    st.write(f'### {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})')
                    st.write(f'Ranking Score: {score}')
                    if st.button('Show Image', key=modelVersion_id, use_container_width=True):
                        self._render_images(image_display, modelVersion_id, f'{icon} {model_name}, {modelVersion_name}')

    def _render_images(self, placeholder, modelVersion_id, caption):
        """Fill ``placeholder`` with a grid of the images generated by one model version."""
        images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
        with placeholder.container():
            st.write('---')
            st.write(f'### Images generated with {caption}')
            image_cols = st.columns(self._IMAGE_COLUMNS)
            for idx, image_id in enumerate(images):
                with image_cols[idx % self._IMAGE_COLUMNS]:
                    image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{image_id}.png"
                    st.image(image, use_column_width=True)

    def score_calculator(self, results, db_table):
        """Aggregate ranking rows into a ``{modelVersion_id: score}`` dict.

        Rows are processed in ascending ``battletime`` order.

        * ``battle_results``: the winner gains 1 point plus the loser's current
          score. A loser that has never won is still listed, with 0.
        * ``sort_results``: ``position1``..``position4`` earn 5, 3, 1 and 0 points.

        Any other ``db_table`` yields an empty dict.
        """
        results = sorted(results, key=lambda record: record['battletime'])
        standings = {}
        if db_table == 'battle_results':
            for record in results:
                winner, loser = record['winner'], record['loser']
                standings[winner] = standings.get(winner, 0) + 1
                standings.setdefault(loser, 0)
                # Beating a model also earns the points that model had accumulated.
                standings[winner] += standings[loser]
        elif db_table == 'sort_results':
            pts_map = {'position1': 5, 'position2': 3, 'position3': 1, 'position4': 0}
            for record in results:
                for position, points in pts_map.items():
                    standings[record[position]] = standings.get(record[position], 0) + points
        return standings

    def app(self):
        """Render the page: ranking-mode picker, tag picker and leaderboard."""
        st.title('Your Preferred Models', help="Scores are calculated based on your ranking results.")
        mode = st.sidebar.radio('Ranking mode', ['Sort', 'Battle'], horizontal=True, index=1)
        db_table = 'sort_results' if mode == 'Sort' else 'battle_results'
        # Tags that the current user has ranking rows for in this mode.
        rows = self._fetch_user_rows(
            f"SELECT DISTINCT tag FROM {self._checked_table(db_table)} WHERE username = %s AND timestamp = %s"
        )
        if not rows:
            st.info(f'No rankings are finished with {mode} mode yet.')
            return
        tags = ['all'] + [row['tag'] for row in rows]
        tag = self.sidebar(tags, mode)
        self.leaderboard(tag, db_table)
if __name__ == "__main__":
    # Page entry point: route the visitor to login / gallery / ranking until at
    # least one ranking session is finished, then render the dashboard.
    st.set_page_config(layout="wide")

    if 'user_id' not in st.session_state:
        st.warning('Please log in first.')
        if st.button('Go to Home Page'):
            switch_page("home")
    elif 'progress' not in st.session_state:
        st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
        if st.button('🖼️ Go to Gallery'):
            switch_page('gallery')
    else:
        # Keys of st.session_state.progress whose value is 'finished'.
        session_finished = [key for key, value in st.session_state.progress.items() if value == 'finished']
        if not session_finished:
            st.info('A dashboard showing your preferred models will appear after you finish any ranking session.')
            if st.button('🎖️ Go to Ranking'):
                switch_page('ranking')
            if st.button('🖼️ Go to Gallery'):
                switch_page('gallery')
        else:
            roster, promptBook, images_ds = load_hf_dataset()  # images_ds is not used on this page
            # Module-level global: DashboardApp's query methods read RANKING_CONN.
            RANKING_CONN = connect_to_db()
            app = DashboardApp(roster, promptBook, session_finished)
            app.app()