Ricercar commited on
Commit
60aa916
1 Parent(s): bae9cca

append results page with minor promblem

Browse files

results page ui finished
the score calculator for battle mode is in progress

Files changed (2) hide show
  1. pages/Gallery.py +3 -3
  2. pages/Resutls.py +170 -0
pages/Gallery.py CHANGED
@@ -565,9 +565,9 @@ def load_hf_dataset():
565
  # images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
566
  images_ds = None # set to None for now since we use s3 bucket to store images
567
 
568
- # process dataset
569
- roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
570
- 'model_download_count']].drop_duplicates().reset_index(drop=True)
571
 
572
  # add 'custom_score_weights' column to promptBook if not exist
573
  if 'weighted_score_sum' not in promptBook.columns:
 
565
  # images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
566
  images_ds = None # set to None for now since we use s3 bucket to store images
567
 
568
+ # # process dataset
569
+ # roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
570
+ # 'model_download_count']].drop_duplicates().reset_index(drop=True)
571
 
572
  # add 'custom_score_weights' column to promptBook if not exist
573
  if 'weighted_score_sum' not in promptBook.columns:
pages/Resutls.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import datasets
4
+ import numpy as np
5
+ import pandas as pd
6
+ import pymysql.cursors
7
+ import streamlit as st
8
+
9
+ from streamlit_elements import elements, mui, html, dashboard, nivo
10
+ from streamlit_extras.switch_page_button import switch_page
11
+ from streamlit_extras.metric_cards import style_metric_cards
12
+ from streamlit_extras.stylable_container import stylable_container
13
+
14
+ from pages.Gallery import load_hf_dataset
15
+ from pages.Ranking import connect_to_db
16
+
17
+
18
+ class DashboardApp:
19
+ def __init__(self, roster, promptBook, session_finished):
20
+ self.roster = roster
21
+ self.promptBook = promptBook
22
+ self.session_finished = session_finished
23
+
24
+ def sidebar(self, tags, mode):
25
+ with st.sidebar:
26
+ tag = st.selectbox('Select a tag', tags, key='tag')
27
+
28
+ return tag
29
+
30
+ def leaderboard(self, tag, db_table):
31
+ tag = '%' if tag == 'all' else tag
32
+
33
+ # get the ranking results of the current user
34
+ curser = RANKING_CONN.cursor()
35
+ curser.execute(f"SELECT * FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag LIKE '{tag}'")
36
+ results = curser.fetchall()
37
+ curser.close()
38
+
39
+ modelVersion_standings = self.score_calculator(results, db_table)
40
+
41
+ # sort the modelVersion_standings by value into a list of tuples in descending order
42
+ modelVersion_standings = sorted(modelVersion_standings.items(), key=lambda x: x[1], reverse=True)
43
+ # show the top 3 in metric cards
44
+ st.write('## Top picks')
45
+ n = 3
46
+ metric_cols = st.columns(n)
47
+
48
+ for i in range(n):
49
+ with metric_cols[i]:
50
+ modelVersion_id = modelVersion_standings[i][0]
51
+ winning_times = modelVersion_standings[i][1]
52
+
53
+ print(self.roster)
54
+
55
+ model_name, modelVersion_name, url = self.roster[self.roster['modelVersion_id'] == modelVersion_id][['model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
56
+
57
+ # st.metric(label=str(modelVersion_id) + ' ' + model_name, value=modelVersion_name, delta=f'Ranking Score: {winning_times}', delta_color='off')
58
+ # # st.write(f'https://civitai.com/models/{modelVersion_id}')
59
+ #
60
+ # style_metric_cards(border_left_color='gold')
61
+ # st.button(str(modelVersion_id), on_click=lambda: os.system(f'open https://civitai.com/models/{modelVersion_id}'), key=modelVersion_id, use_container_width=True)
62
+
63
+ metric_card = stylable_container(
64
+ key="container_with_border",
65
+ css_styles="""
66
+ {
67
+ border: 1px solid rgba(49, 51, 63, 0.2);
68
+ border-left: 0.5rem solid silver;
69
+ border-radius: 5px;
70
+ padding: calc(1em + 5px);
71
+ gap: 0.5em;
72
+ box-shadow: 0 0 2rem rgba(0, 0, 0, 0.08);
73
+ overflow: scroll;
74
+ }
75
+ """,
76
+ )
77
+
78
+ with metric_card:
79
+ icon = '🥇'if i == 0 else '🥈' if i == 1 else '🥉'
80
+ st.write(modelVersion_id)
81
+ st.write(f'### {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{modelVersion_id})')
82
+ st.write(f'Ranking Score: {winning_times}')
83
+
84
+ st.write('---')
85
+
86
+ st.write('## Detailed information of all selected models')
87
+ detailed_info = pd.merge(pd.DataFrame(modelVersion_standings, columns=['modelVersion_id', 'ranking_score']), self.roster, on='modelVersion_id')
88
+ st.data_editor(detailed_info, hide_index=True, disabled=True)
89
+
90
+
91
+ def score_calculator(self, results, db_table):
92
+ modelVersion_standings = {}
93
+ if db_table == 'battle_results':
94
+ for record in results:
95
+ modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
96
+ # add the winning time of the loser
97
+ curser = RANKING_CONN.cursor()
98
+ curser.execute(f"SELECT COUNT(*) FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND winner = '{record['loser']}'")
99
+ modelVersion_standings[record['winner']] += curser.fetchone()['COUNT(*)']
100
+ curser.close()
101
+
102
+ # add the loser who never wins
103
+ if record['loser'] not in modelVersion_standings:
104
+ modelVersion_standings[record['loser']] = 0
105
+
106
+ elif db_table == 'sort_results':
107
+ pts_map = {'position1': 5, 'position2': 3, 'position3': 1, 'position4': 0}
108
+ for record in results:
109
+ for i in range(1, 5):
110
+ modelVersion_standings[record[f'position{i}']] = modelVersion_standings.get(record[f'position{i}'], 0) + pts_map[f'position{i}']
111
+
112
+ return modelVersion_standings
113
+
114
+
115
+
116
+ def app(self):
117
+ st.title('Your Preferred Models', help="Scores are calculated based on your ranking results.")
118
+
119
+ mode = st.sidebar.radio('Ranking mode', ['Sort', 'Battle'], horizontal=True)
120
+ # get tags from database of the current user
121
+ db_table = 'sort_results' if mode == 'Sort' else 'battle_results'
122
+
123
+ tags = ['all']
124
+ curser = RANKING_CONN.cursor()
125
+ curser.execute(
126
+ f"SELECT DISTINCT tag FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}'")
127
+ for row in curser.fetchall():
128
+ tags.append(row['tag'])
129
+ curser.close()
130
+
131
+ if tags == ['all']:
132
+ st.info(f'No rankings are finished with {mode} mode yet.')
133
+
134
+ else:
135
+ tag = self.sidebar(tags, mode)
136
+ self.leaderboard(tag, db_table)
137
+
138
+
139
+ if __name__ == "__main__":
140
+ st.set_page_config(layout="wide")
141
+
142
+ if 'user_id' not in st.session_state:
143
+ st.warning('Please log in first.')
144
+ home_btn = st.button('Go to Home Page')
145
+ if home_btn:
146
+ switch_page("home")
147
+
148
+ else:
149
+ session_finished = []
150
+
151
+ for key, value in st.session_state.progress.items():
152
+ if value == 'finished':
153
+ session_finished.append(key)
154
+
155
+ if len(session_finished) == 0:
156
+ st.info('A dashboard showing your preferred models will appear after you finish any ranking session.')
157
+ ranking_btn = st.button('🎖️ Back to Ranking')
158
+ if ranking_btn:
159
+ switch_page('ranking')
160
+ gallery_btn = st.button('🖼️ Back to Gallery')
161
+ if gallery_btn:
162
+ switch_page('gallery')
163
+
164
+ else:
165
+ roster, promptBook, images_ds = load_hf_dataset()
166
+ RANKING_CONN = connect_to_db()
167
+ app = DashboardApp(roster, promptBook, session_finished)
168
+ app.app()
169
+
170
+