Spaces:
Sleeping
Sleeping
fully connected to database
Browse filesadd a counter called 'epoch'
add checked indicator in ranking page
better button logic
- Home.py +27 -5
- pages/Gallery.py +72 -32
- pages/Ranking.py +120 -55
- pages/Summary.py +74 -53
Home.py
CHANGED
@@ -1,7 +1,10 @@
|
|
1 |
-
|
2 |
-
import streamlit as st
|
3 |
import random
|
4 |
-
|
|
|
|
|
|
|
|
|
5 |
from streamlit_extras.switch_page_button import switch_page
|
6 |
|
7 |
|
@@ -31,27 +34,32 @@ def login():
|
|
31 |
|
32 |
|
33 |
def save_user_id(user_id):
|
|
|
34 |
print(user_id)
|
35 |
if not user_id:
|
36 |
user_id = 'anonymous' + str(random.randint(0, 100000))
|
37 |
st.session_state.user_id = [user_id, datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
|
38 |
st.session_state.assigned_rank_mode = random.choice(['Drag and Sort', 'Battle'])
|
|
|
39 |
|
40 |
|
41 |
def logout():
|
42 |
st.session_state.pop('user_id', None)
|
43 |
st.session_state.pop('selected_dict', None)
|
|
|
44 |
st.session_state.pop('score_weights', None)
|
45 |
st.session_state.pop('gallery_state', None)
|
46 |
st.session_state.pop('edit_state', None)
|
47 |
st.session_state.pop('progress', None)
|
|
|
|
|
48 |
st.session_state.pop('gallery_focus', None)
|
49 |
st.session_state.pop('assigned_rank_mode', None)
|
50 |
st.session_state.pop('show_NSFW', None)
|
51 |
st.session_state.pop('modelVersion_standings', None)
|
52 |
|
53 |
|
54 |
-
def
|
55 |
with st.sidebar:
|
56 |
st.write('## About')
|
57 |
st.write(
|
@@ -63,10 +71,24 @@ def info():
|
|
63 |
)
|
64 |
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
if __name__ == '__main__':
|
67 |
# print(st.source_util.get_pages('Home.py'))
|
68 |
st.set_page_config(page_title="Login", page_icon="🏠", layout="wide")
|
69 |
-
|
70 |
st.write('A Research by [MAPS Lab](https://whongyi.github.io/MAPS-research), [NYU Shanghai](https://shanghai.nyu.edu)')
|
71 |
st.title("🙌 Welcome to GEMRec Gallery!")
|
72 |
|
|
|
1 |
+
import os
|
|
|
2 |
import random
|
3 |
+
|
4 |
+
import pymysql.cursors
|
5 |
+
import streamlit as st
|
6 |
+
|
7 |
+
from datetime import datetime
|
8 |
from streamlit_extras.switch_page_button import switch_page
|
9 |
|
10 |
|
|
|
34 |
|
35 |
|
36 |
def save_user_id(user_id):
|
37 |
+
user_id = user_id[:60]
|
38 |
print(user_id)
|
39 |
if not user_id:
|
40 |
user_id = 'anonymous' + str(random.randint(0, 100000))
|
41 |
st.session_state.user_id = [user_id, datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
|
42 |
st.session_state.assigned_rank_mode = random.choice(['Drag and Sort', 'Battle'])
|
43 |
+
st.session_state.epoch = {'gallery': 0, 'ranking': {}, 'summary': {'overall': 0}}
|
44 |
|
45 |
|
46 |
def logout():
|
47 |
st.session_state.pop('user_id', None)
|
48 |
st.session_state.pop('selected_dict', None)
|
49 |
+
st.session_state.pop('epoch', None)
|
50 |
st.session_state.pop('score_weights', None)
|
51 |
st.session_state.pop('gallery_state', None)
|
52 |
st.session_state.pop('edit_state', None)
|
53 |
st.session_state.pop('progress', None)
|
54 |
+
st.session_state.pop('pointer', None)
|
55 |
+
st.session_state.pop('counter', None)
|
56 |
st.session_state.pop('gallery_focus', None)
|
57 |
st.session_state.pop('assigned_rank_mode', None)
|
58 |
st.session_state.pop('show_NSFW', None)
|
59 |
st.session_state.pop('modelVersion_standings', None)
|
60 |
|
61 |
|
62 |
+
def project_info():
|
63 |
with st.sidebar:
|
64 |
st.write('## About')
|
65 |
st.write(
|
|
|
71 |
)
|
72 |
|
73 |
|
74 |
+
def connect_to_db():
|
75 |
+
conn = pymysql.connect(
|
76 |
+
host=os.environ.get('RANKING_DB_HOST'),
|
77 |
+
port=3306,
|
78 |
+
database='myRanking',
|
79 |
+
user=os.environ.get('RANKING_DB_USER'),
|
80 |
+
password=os.environ.get('RANKING_DB_PASSWORD'),
|
81 |
+
charset='utf8mb4',
|
82 |
+
cursorclass=pymysql.cursors.DictCursor
|
83 |
+
)
|
84 |
+
|
85 |
+
return conn
|
86 |
+
|
87 |
+
|
88 |
if __name__ == '__main__':
|
89 |
# print(st.source_util.get_pages('Home.py'))
|
90 |
st.set_page_config(page_title="Login", page_icon="🏠", layout="wide")
|
91 |
+
project_info()
|
92 |
st.write('A Research by [MAPS Lab](https://whongyi.github.io/MAPS-research), [NYU Shanghai](https://shanghai.nyu.edu)')
|
93 |
st.title("🙌 Welcome to GEMRec Gallery!")
|
94 |
|
pages/Gallery.py
CHANGED
@@ -12,6 +12,7 @@ import streamlit.components.v1 as components
|
|
12 |
|
13 |
from bs4 import BeautifulSoup
|
14 |
from datasets import load_dataset, Dataset, load_from_disk
|
|
|
15 |
from huggingface_hub import login
|
16 |
from streamlit_agraph import agraph, Node, Edge, Config
|
17 |
from streamlit_extras.switch_page_button import switch_page
|
@@ -19,6 +20,7 @@ from streamlit_extras.tags import tagger_component
|
|
19 |
from streamlit_extras.no_default_selectbox import selectbox
|
20 |
from sklearn.svm import LinearSVC
|
21 |
|
|
|
22 |
|
23 |
class GalleryApp:
|
24 |
def __init__(self, promptBook, images_ds):
|
@@ -203,7 +205,7 @@ class GalleryApp:
|
|
203 |
|
204 |
# remove coloring from tag
|
205 |
tag = self.text_coloring_remove(tag)
|
206 |
-
print('tag: ', tag)
|
207 |
|
208 |
# print('current state: ', st.session_state.gallery_state)
|
209 |
|
@@ -237,7 +239,7 @@ class GalleryApp:
|
|
237 |
|
238 |
# remove coloring from prompt
|
239 |
selected_prompt = self.text_coloring_remove(selected_prompt)
|
240 |
-
print('selected_prompt: ', selected_prompt)
|
241 |
st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
|
242 |
|
243 |
if selected_prompt is None:
|
@@ -282,15 +284,15 @@ class GalleryApp:
|
|
282 |
pass
|
283 |
|
284 |
if has_selection:
|
285 |
-
checkout = st.button('Check out selections ➡️', use_container_width=True, type='primary')
|
286 |
-
if checkout:
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
else:
|
295 |
st.button(':orange[👇 **Select images you like below**]', disabled=True, use_container_width=True)
|
296 |
try:
|
@@ -308,6 +310,12 @@ class GalleryApp:
|
|
308 |
|
309 |
self.checkout_mode(tag, items)
|
310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
|
312 |
def random_gallery_focus(self, tags):
|
313 |
st.session_state.gallery_focus['tag'] = random.choice(tags)
|
@@ -315,7 +323,6 @@ class GalleryApp:
|
|
315 |
prompts = self.promptBook[self.promptBook['tag'] == st.session_state.gallery_focus['tag']]['prompt'].unique()
|
316 |
st.session_state.gallery_focus['prompt'] = random.choice(prompts)
|
317 |
|
318 |
-
|
319 |
def graph_mode(self, prompt_id, items):
|
320 |
graph_cols = st.columns([3, 1])
|
321 |
# prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
|
@@ -353,24 +360,11 @@ class GalleryApp:
|
|
353 |
|
354 |
if checked:
|
355 |
# deselect = st.button('Deselect', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True)
|
356 |
-
deselect = st.form_submit_button('Deselect', use_container_width=True)
|
357 |
-
if deselect:
|
358 |
-
st.session_state.selected_dict[item['prompt_id']].remove(item['modelVersion_id'])
|
359 |
-
self.remove_ranking_states(item['prompt_id'])
|
360 |
-
st.experimental_rerun()
|
361 |
|
362 |
else:
|
363 |
# select = st.button('Select', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True, type='primary')
|
364 |
-
select = st.form_submit_button('Select', use_container_width=True, type='primary')
|
365 |
-
if select:
|
366 |
-
st.session_state.selected_dict[item['prompt_id']].append(item['modelVersion_id'])
|
367 |
-
self.remove_ranking_states(item['prompt_id'])
|
368 |
-
|
369 |
-
# add focus to session state
|
370 |
-
st.session_state.gallery_focus['tag'] = item['tag']
|
371 |
-
st.session_state.gallery_focus['prompt'] = item['prompt']
|
372 |
-
|
373 |
-
st.experimental_rerun()
|
374 |
|
375 |
# st.write(item)
|
376 |
infos = ['model_name', 'modelVersion_name', 'model_download_count', 'clip_score', 'mcos_score',
|
@@ -384,6 +378,17 @@ class GalleryApp:
|
|
384 |
else:
|
385 |
st.info('Please click on an image to show')
|
386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
def checkout_mode(self, tag, items):
|
388 |
# st.write(items)
|
389 |
if len(items) > 0:
|
@@ -407,7 +412,7 @@ class GalleryApp:
|
|
407 |
['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
|
408 |
'total_score', 'model_download_count', 'clip_score', 'mcos_score',
|
409 |
'norm_nsfw'],
|
410 |
-
label_visibility='collapsed', key=f'info_{prompt_id}', placeholder='Select what
|
411 |
|
412 |
with checkout_panel[-1]:
|
413 |
checkout_buttons = st.columns([1, 1, 1])
|
@@ -418,7 +423,7 @@ class GalleryApp:
|
|
418 |
st.session_state.gallery_focus['prompt'] = prompt
|
419 |
print(st.session_state.gallery_focus)
|
420 |
st.session_state.gallery_state = 'graph'
|
421 |
-
st.
|
422 |
|
423 |
with checkout_buttons[1]:
|
424 |
# init edit state
|
@@ -429,7 +434,7 @@ class GalleryApp:
|
|
429 |
edit = st.button('Edit', key=f'checkout_edit_{prompt_id}', use_container_width=True)
|
430 |
if edit:
|
431 |
st.session_state.edit_state = True
|
432 |
-
st.
|
433 |
else:
|
434 |
done = st.button('Done', key=f'checkout_done_{prompt_id}', use_container_width=True)
|
435 |
if done:
|
@@ -443,7 +448,7 @@ class GalleryApp:
|
|
443 |
st.session_state.selected_dict[prompt_id].append(int(keys[2]))
|
444 |
|
445 |
st.session_state.edit_state = False
|
446 |
-
st.
|
447 |
|
448 |
with checkout_buttons[-1]:
|
449 |
proceed = st.button('Proceed ➡️', key=f'checkout_proceed_{prompt_id}', use_container_width=True,
|
@@ -452,6 +457,40 @@ class GalleryApp:
|
|
452 |
st.session_state.gallery_focus['tag'] = tag
|
453 |
st.session_state.gallery_focus['prompt'] = prompt
|
454 |
st.session_state.gallery_state = 'graph'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
455 |
switch_page('ranking')
|
456 |
|
457 |
self.gallery_standard(items[items['prompt_id'] == prompt_id].reset_index(drop=True), 4, info, show_checkbox=st.session_state.edit_state)
|
@@ -463,7 +502,7 @@ class GalleryApp:
|
|
463 |
st.session_state.gallery_focus['tag'] = tag
|
464 |
st.session_state.gallery_focus['prompt'] = None
|
465 |
st.session_state.gallery_state = 'graph'
|
466 |
-
st.
|
467 |
|
468 |
def remove_ranking_states(self, prompt_id):
|
469 |
# for drag sort
|
@@ -547,6 +586,7 @@ if __name__ == "__main__":
|
|
547 |
if home_btn:
|
548 |
switch_page("home")
|
549 |
else:
|
|
|
550 |
roster, promptBook, images_ds = load_hf_dataset(st.session_state.show_NSFW)
|
551 |
|
552 |
app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
|
|
|
12 |
|
13 |
from bs4 import BeautifulSoup
|
14 |
from datasets import load_dataset, Dataset, load_from_disk
|
15 |
+
from datetime import datetime
|
16 |
from huggingface_hub import login
|
17 |
from streamlit_agraph import agraph, Node, Edge, Config
|
18 |
from streamlit_extras.switch_page_button import switch_page
|
|
|
20 |
from streamlit_extras.no_default_selectbox import selectbox
|
21 |
from sklearn.svm import LinearSVC
|
22 |
|
23 |
+
from Home import connect_to_db
|
24 |
|
25 |
class GalleryApp:
|
26 |
def __init__(self, promptBook, images_ds):
|
|
|
205 |
|
206 |
# remove coloring from tag
|
207 |
tag = self.text_coloring_remove(tag)
|
208 |
+
# print('tag: ', tag)
|
209 |
|
210 |
# print('current state: ', st.session_state.gallery_state)
|
211 |
|
|
|
239 |
|
240 |
# remove coloring from prompt
|
241 |
selected_prompt = self.text_coloring_remove(selected_prompt)
|
242 |
+
# print('selected_prompt: ', selected_prompt)
|
243 |
st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
|
244 |
|
245 |
if selected_prompt is None:
|
|
|
284 |
pass
|
285 |
|
286 |
if has_selection:
|
287 |
+
checkout = st.button('Check out selections ➡️', use_container_width=True, type='primary', on_click=self.switch_to_checkout, args=(tag, selected_prompt))
|
288 |
+
# if checkout:
|
289 |
+
# # add focus to session state
|
290 |
+
# st.session_state.gallery_focus['tag'] = tag
|
291 |
+
# st.session_state.gallery_focus['prompt'] = selected_prompt
|
292 |
+
#
|
293 |
+
# st.session_state.gallery_state = 'check out'
|
294 |
+
# # print(st.session_state.gallery_state)
|
295 |
+
# st.rerun()
|
296 |
else:
|
297 |
st.button(':orange[👇 **Select images you like below**]', disabled=True, use_container_width=True)
|
298 |
try:
|
|
|
310 |
|
311 |
self.checkout_mode(tag, items)
|
312 |
|
313 |
+
def switch_to_checkout(self, tag, selected_prompt):
|
314 |
+
# add focus to session state
|
315 |
+
st.session_state.gallery_focus['tag'] = tag
|
316 |
+
st.session_state.gallery_focus['prompt'] = selected_prompt
|
317 |
+
|
318 |
+
st.session_state.gallery_state = 'check out'
|
319 |
|
320 |
def random_gallery_focus(self, tags):
|
321 |
st.session_state.gallery_focus['tag'] = random.choice(tags)
|
|
|
323 |
prompts = self.promptBook[self.promptBook['tag'] == st.session_state.gallery_focus['tag']]['prompt'].unique()
|
324 |
st.session_state.gallery_focus['prompt'] = random.choice(prompts)
|
325 |
|
|
|
326 |
def graph_mode(self, prompt_id, items):
|
327 |
graph_cols = st.columns([3, 1])
|
328 |
# prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
|
|
|
360 |
|
361 |
if checked:
|
362 |
# deselect = st.button('Deselect', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True)
|
363 |
+
deselect = st.form_submit_button('Deselect', use_container_width=True, on_click=self.image_selection_control, args=(item['tag'], item['prompt'], item['prompt_id'], item['modelVersion_id'], False))
|
|
|
|
|
|
|
|
|
364 |
|
365 |
else:
|
366 |
# select = st.button('Select', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True, type='primary')
|
367 |
+
select = st.form_submit_button('Select', use_container_width=True, type='primary', on_click=self.image_selection_control, args=(item['tag'], item['prompt'], item['prompt_id'], item['modelVersion_id'], True))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
|
369 |
# st.write(item)
|
370 |
infos = ['model_name', 'modelVersion_name', 'model_download_count', 'clip_score', 'mcos_score',
|
|
|
378 |
else:
|
379 |
st.info('Please click on an image to show')
|
380 |
|
381 |
+
def image_selection_control(self, tag, prompt, prompt_id, modelVersion_id, to_select):
|
382 |
+
self.remove_ranking_states(prompt_id)
|
383 |
+
if to_select:
|
384 |
+
st.session_state.selected_dict[prompt_id].append(modelVersion_id)
|
385 |
+
# add focus to session state
|
386 |
+
st.session_state.gallery_focus['tag'] = tag
|
387 |
+
st.session_state.gallery_focus['prompt'] = prompt
|
388 |
+
|
389 |
+
else:
|
390 |
+
st.session_state.selected_dict[prompt_id].remove(modelVersion_id)
|
391 |
+
|
392 |
def checkout_mode(self, tag, items):
|
393 |
# st.write(items)
|
394 |
if len(items) > 0:
|
|
|
412 |
['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
|
413 |
'total_score', 'model_download_count', 'clip_score', 'mcos_score',
|
414 |
'norm_nsfw'],
|
415 |
+
label_visibility='collapsed', key=f'info_{prompt_id}', placeholder='Select what info to show')
|
416 |
|
417 |
with checkout_panel[-1]:
|
418 |
checkout_buttons = st.columns([1, 1, 1])
|
|
|
423 |
st.session_state.gallery_focus['prompt'] = prompt
|
424 |
print(st.session_state.gallery_focus)
|
425 |
st.session_state.gallery_state = 'graph'
|
426 |
+
st.rerun()
|
427 |
|
428 |
with checkout_buttons[1]:
|
429 |
# init edit state
|
|
|
434 |
edit = st.button('Edit', key=f'checkout_edit_{prompt_id}', use_container_width=True)
|
435 |
if edit:
|
436 |
st.session_state.edit_state = True
|
437 |
+
st.rerun()
|
438 |
else:
|
439 |
done = st.button('Done', key=f'checkout_done_{prompt_id}', use_container_width=True)
|
440 |
if done:
|
|
|
448 |
st.session_state.selected_dict[prompt_id].append(int(keys[2]))
|
449 |
|
450 |
st.session_state.edit_state = False
|
451 |
+
st.rerun()
|
452 |
|
453 |
with checkout_buttons[-1]:
|
454 |
proceed = st.button('Proceed ➡️', key=f'checkout_proceed_{prompt_id}', use_container_width=True,
|
|
|
457 |
st.session_state.gallery_focus['tag'] = tag
|
458 |
st.session_state.gallery_focus['prompt'] = prompt
|
459 |
st.session_state.gallery_state = 'graph'
|
460 |
+
|
461 |
+
print('selected_dict: ', st.session_state.selected_dict)
|
462 |
+
|
463 |
+
# # save the user selection to database
|
464 |
+
cursor = GALLERY_CONN.cursor()
|
465 |
+
st.session_state.epoch['gallery'] += 1
|
466 |
+
checkouttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
467 |
+
# for modelVersion_id in st.session_state.selected_dict[prompt_id]:
|
468 |
+
for key, values in st.session_state.selected_dict.items():
|
469 |
+
# print('key: ', key, 'values: ', values)
|
470 |
+
key_tag = self.promptBook[self.promptBook['prompt_id'] == key]['tag'].unique()[0]
|
471 |
+
for value in values:
|
472 |
+
query = "INSERT INTO gallery_selections (username, timestamp, tag, prompt_id, modelVersion_id, checkouttime, epoch) VALUES ('{}', '{}', '{}', '{}', {}, '{}', {})".format(st.session_state.user_id[0], st.session_state.user_id[1], key_tag, key, value, checkouttime, st.session_state.epoch['gallery'])
|
473 |
+
print(query)
|
474 |
+
cursor.execute(query)
|
475 |
+
GALLERY_CONN.commit()
|
476 |
+
cursor.close()
|
477 |
+
|
478 |
+
# get the largest epoch number of this user and prompt
|
479 |
+
cursor = GALLERY_CONN.cursor()
|
480 |
+
db_table = 'battle_results' if st.session_state.assigned_rank_mode=='Battle' else 'sort_results'
|
481 |
+
query = "SELECT MAX(epoch) FROM {} WHERE username = '{}' AND timestamp = '{}' AND prompt_id = {}".format(db_table, st.session_state.user_id[0], st.session_state.user_id[1], prompt_id)
|
482 |
+
cursor.execute(query)
|
483 |
+
max_epoch = cursor.fetchone()['MAX(epoch)'],
|
484 |
+
# print('max epoch: ', max_epoch, type(max_epoch))
|
485 |
+
cursor.close()
|
486 |
+
|
487 |
+
try:
|
488 |
+
st.session_state.epoch['ranking'][prompt_id] = max_epoch[0] + 1
|
489 |
+
except TypeError:
|
490 |
+
st.session_state.epoch['ranking'][prompt_id] = 1
|
491 |
+
# st.session_state.epoch['summary'][tag] = st.session_state.epoch['summary'].get(tag, 0) + 1
|
492 |
+
# st.session_state.epoch['summary']['overall'] += 1
|
493 |
+
print('epoch: ', st.session_state.epoch)
|
494 |
switch_page('ranking')
|
495 |
|
496 |
self.gallery_standard(items[items['prompt_id'] == prompt_id].reset_index(drop=True), 4, info, show_checkbox=st.session_state.edit_state)
|
|
|
502 |
st.session_state.gallery_focus['tag'] = tag
|
503 |
st.session_state.gallery_focus['prompt'] = None
|
504 |
st.session_state.gallery_state = 'graph'
|
505 |
+
st.rerun()
|
506 |
|
507 |
def remove_ranking_states(self, prompt_id):
|
508 |
# for drag sort
|
|
|
586 |
if home_btn:
|
587 |
switch_page("home")
|
588 |
else:
|
589 |
+
GALLERY_CONN = connect_to_db()
|
590 |
roster, promptBook, images_ds = load_hf_dataset(st.session_state.show_NSFW)
|
591 |
|
592 |
app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
|
pages/Ranking.py
CHANGED
@@ -11,6 +11,7 @@ from streamlit_elements import elements, mui, html, dashboard, nivo
|
|
11 |
from streamlit_extras.switch_page_button import switch_page
|
12 |
|
13 |
from pages.Gallery import load_hf_dataset
|
|
|
14 |
|
15 |
|
16 |
class RankingApp:
|
@@ -122,7 +123,7 @@ class RankingApp:
|
|
122 |
for k in st.session_state.ranking[prompt_id][batch_idx].keys():
|
123 |
st.session_state.ranking[prompt_id][batch_idx][k] = sorted_list.index(k)
|
124 |
|
125 |
-
def dragsort_mode(self, items, prompt_id, batch_num):
|
126 |
st.session_state.counter[prompt_id] = 0 if prompt_id not in st.session_state.counter else \
|
127 |
st.session_state.counter[prompt_id]
|
128 |
|
@@ -149,16 +150,16 @@ class RankingApp:
|
|
149 |
with control:
|
150 |
if st.session_state.counter[prompt_id] < batch_num - 1:
|
151 |
st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch',
|
152 |
-
kwargs={'prompt_id': prompt_id}, use_container_width=True, type='primary')
|
153 |
else:
|
154 |
st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished',
|
155 |
-
kwargs={'prompt_id': prompt_id, 'progress': 'finished'}, use_container_width=True, type='primary')
|
156 |
|
157 |
if st.session_state.counter[prompt_id] > 0:
|
158 |
st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch',
|
159 |
kwargs={'prompt_id': prompt_id}, use_container_width=True)
|
160 |
|
161 |
-
def next_batch(self, prompt_id, progress=None):
|
162 |
|
163 |
curser = RANKING_CONN.cursor()
|
164 |
|
@@ -168,27 +169,36 @@ class RankingApp:
|
|
168 |
modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
|
169 |
position_version_dict[position] = modelVersion_id
|
170 |
|
171 |
-
# get all records of this user and prompt
|
172 |
-
query = "SELECT * FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s"
|
173 |
-
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id))
|
174 |
-
results = curser.fetchall()
|
175 |
-
print(results)
|
176 |
-
|
177 |
-
# remove the old ranking with the same modelVersion_id if exists
|
178 |
-
for result in results:
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
187 |
|
188 |
curser.close()
|
189 |
RANKING_CONN.commit()
|
190 |
|
191 |
if progress == 'finished':
|
|
|
|
|
|
|
|
|
|
|
192 |
st.session_state.progress[prompt_id] = 'finished'
|
193 |
# drop 'modelVersion_standings' from session state if exists
|
194 |
st.session_state.pop('modelVersion_standings', None)
|
@@ -198,7 +208,7 @@ class RankingApp:
|
|
198 |
def prev_batch(self, prompt_id):
|
199 |
st.session_state.counter[prompt_id] -= 1
|
200 |
|
201 |
-
def battle_images(self, items, prompt_id):
|
202 |
if 'pointer' not in st.session_state:
|
203 |
st.session_state.pointer = {}
|
204 |
|
@@ -221,7 +231,7 @@ class RankingApp:
|
|
221 |
# total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
|
222 |
# st.write(f'Total Score: {total_score}')
|
223 |
|
224 |
-
btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
225 |
st.image(img_url, use_column_width=True)
|
226 |
|
227 |
with right:
|
@@ -232,10 +242,10 @@ class RankingApp:
|
|
232 |
# total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
|
233 |
# st.write(f'Total Score: {total_score}')
|
234 |
|
235 |
-
btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
236 |
st.image(img_url, use_column_width=True)
|
237 |
|
238 |
-
def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
|
239 |
loser = 'left' if winner == 'right' else 'right'
|
240 |
battletime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
241 |
|
@@ -244,19 +254,24 @@ class RankingApp:
|
|
244 |
winner_modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_ids[st.session_state.pointer[prompt_id][winner]]]['modelVersion_id'].values[0]
|
245 |
loser_modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_ids[st.session_state.pointer[prompt_id][loser]]]['modelVersion_id'].values[0]
|
246 |
|
247 |
-
# remove the old battle result if exists
|
248 |
-
query = "DELETE FROM battle_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND winner = %s AND loser = %s"
|
249 |
-
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, winner_modelVersion_id, loser_modelVersion_id))
|
250 |
-
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, loser_modelVersion_id, winner_modelVersion_id))
|
251 |
|
252 |
# insert the battle result into the database
|
253 |
-
query = "INSERT INTO battle_results (username, timestamp, tag, prompt_id, winner, loser, battletime) VALUES (%s, %s, %s, %s, %s, %s, %s)"
|
254 |
-
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, winner_modelVersion_id, loser_modelVersion_id, battletime))
|
255 |
|
256 |
curser.close()
|
257 |
RANKING_CONN.commit()
|
258 |
|
259 |
if curr_position == total_num - 1:
|
|
|
|
|
|
|
|
|
|
|
260 |
st.session_state.progress[prompt_id] = 'finished'
|
261 |
|
262 |
# drop 'modelVersion_standings' from session state if exists
|
@@ -266,8 +281,8 @@ class RankingApp:
|
|
266 |
else:
|
267 |
st.session_state.pointer[prompt_id][loser] = curr_position + 1
|
268 |
|
269 |
-
def battle_mode(self, items, prompt_id):
|
270 |
-
self.battle_images(items, prompt_id)
|
271 |
|
272 |
def app(self):
|
273 |
st.write('### Generative Model Ranking')
|
@@ -288,15 +303,43 @@ class RankingApp:
|
|
288 |
'tag'] in prompt_tags else 0
|
289 |
print(tag_idx)
|
290 |
|
291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
|
|
|
|
293 |
prompts = np.sort(items['prompt'].unique())[::-1].tolist()
|
294 |
|
295 |
prompt_idx = prompts.index(st.session_state.gallery_focus['prompt']) if st.session_state.gallery_focus[
|
296 |
'prompt'] in prompts else 0
|
297 |
print(prompt_idx)
|
298 |
|
299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
|
301 |
# mode = st.radio('Select a mode', ['Drag and Sort', 'Battle'], index=1)
|
302 |
mode = st.session_state.assigned_rank_mode
|
@@ -317,13 +360,15 @@ class RankingApp:
|
|
317 |
st.session_state.progress[prompt_id] = 'ranking'
|
318 |
|
319 |
if st.session_state.progress[prompt_id] == 'ranking':
|
|
|
320 |
st.caption("We might pair some other images that you haven't selected based on our evaluation matrix.")
|
321 |
if mode == 'Drag and Sort':
|
322 |
-
self.dragsort_mode(items, prompt_id, batch_num)
|
323 |
elif mode == 'Battle':
|
324 |
-
self.battle_mode(items, prompt_id)
|
325 |
|
326 |
elif st.session_state.progress[prompt_id] == 'finished':
|
|
|
327 |
# st.toast('**Summary is available now!**')
|
328 |
# st.write('---')
|
329 |
with st.form(key='ranking_finished'):
|
@@ -348,25 +393,45 @@ class RankingApp:
|
|
348 |
st.form_submit_button('👆 Rank other prompts', use_container_width=True)
|
349 |
|
350 |
with options_panel[3]:
|
351 |
-
restart_btn = st.form_submit_button('🎖️ Re-rank this prompt', use_container_width=True)
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
|
372 |
if __name__ == "__main__":
|
|
|
11 |
from streamlit_extras.switch_page_button import switch_page
|
12 |
|
13 |
from pages.Gallery import load_hf_dataset
|
14 |
+
from Home import connect_to_db
|
15 |
|
16 |
|
17 |
class RankingApp:
|
|
|
123 |
for k in st.session_state.ranking[prompt_id][batch_idx].keys():
|
124 |
st.session_state.ranking[prompt_id][batch_idx][k] = sorted_list.index(k)
|
125 |
|
126 |
+
def dragsort_mode(self, tag, items, prompt_id, batch_num):
|
127 |
st.session_state.counter[prompt_id] = 0 if prompt_id not in st.session_state.counter else \
|
128 |
st.session_state.counter[prompt_id]
|
129 |
|
|
|
150 |
with control:
|
151 |
if st.session_state.counter[prompt_id] < batch_num - 1:
|
152 |
st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch',
|
153 |
+
kwargs={'tag': tag, 'prompt_id': prompt_id}, use_container_width=True, type='primary')
|
154 |
else:
|
155 |
st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished',
|
156 |
+
kwargs={'tag': tag, 'prompt_id': prompt_id, 'progress': 'finished'}, use_container_width=True, type='primary')
|
157 |
|
158 |
if st.session_state.counter[prompt_id] > 0:
|
159 |
st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch',
|
160 |
kwargs={'prompt_id': prompt_id}, use_container_width=True)
|
161 |
|
162 |
+
def next_batch(self, tag, prompt_id, progress=None):
|
163 |
|
164 |
curser = RANKING_CONN.cursor()
|
165 |
|
|
|
169 |
modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
|
170 |
position_version_dict[position] = modelVersion_id
|
171 |
|
172 |
+
# # get all records of this user and prompt
|
173 |
+
# query = "SELECT * FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s"
|
174 |
+
# curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id))
|
175 |
+
# results = curser.fetchall()
|
176 |
+
# print(results)
|
177 |
+
#
|
178 |
+
# # remove the old ranking with the same modelVersion_id if exists
|
179 |
+
# for result in results:
|
180 |
+
# prev_ids = [result['position1'], result['position2'], result['position3'], result['position4']]
|
181 |
+
# curr_ids = [position_version_dict[0], position_version_dict[1], position_version_dict[2], position_version_dict[3]]
|
182 |
+
# if len(set(prev_ids).intersection(set(curr_ids))) == 4:
|
183 |
+
# query = "DELETE FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND position1 = %s AND position2 = %s AND position3 = %s AND position4 = %s"
|
184 |
+
# curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, result['position1'], result['position2'], result['position3'], result['position4']))
|
185 |
+
|
186 |
+
sorttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
187 |
+
# handle the case where user press the 'prev' button
|
188 |
+
query = "DELETE FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND epoch = %s"
|
189 |
+
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, st.session_state.epoch['ranking'][prompt_id]))
|
190 |
+
query = "INSERT INTO sort_results (username, timestamp, tag, prompt_id, position1, position2, position3, position4, sorttime, epoch) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
|
191 |
+
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, position_version_dict[0], position_version_dict[1], position_version_dict[2], position_version_dict[3], sorttime, st.session_state.epoch['ranking'][prompt_id]))
|
192 |
|
193 |
curser.close()
|
194 |
RANKING_CONN.commit()
|
195 |
|
196 |
if progress == 'finished':
|
197 |
+
st.session_state.epoch['ranking'][prompt_id] += 1
|
198 |
+
|
199 |
+
st.session_state.epoch['summary'][tag] = st.session_state.epoch['summary'].get(tag, 0) + 1
|
200 |
+
st.session_state.epoch['summary']['overall'] += 1
|
201 |
+
|
202 |
st.session_state.progress[prompt_id] = 'finished'
|
203 |
# drop 'modelVersion_standings' from session state if exists
|
204 |
st.session_state.pop('modelVersion_standings', None)
|
|
|
208 |
def prev_batch(self, prompt_id):
|
209 |
st.session_state.counter[prompt_id] -= 1
|
210 |
|
211 |
+
def battle_images(self, tag, items, prompt_id):
|
212 |
if 'pointer' not in st.session_state:
|
213 |
st.session_state.pointer = {}
|
214 |
|
|
|
231 |
# total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
|
232 |
# st.write(f'Total Score: {total_score}')
|
233 |
|
234 |
+
btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'tag': tag, 'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
235 |
st.image(img_url, use_column_width=True)
|
236 |
|
237 |
with right:
|
|
|
242 |
# total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
|
243 |
# st.write(f'Total Score: {total_score}')
|
244 |
|
245 |
+
btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'tag': tag, 'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
246 |
st.image(img_url, use_column_width=True)
|
247 |
|
248 |
+
def next_battle(self, tag, prompt_id, image_ids, winner, curr_position, total_num):
|
249 |
loser = 'left' if winner == 'right' else 'right'
|
250 |
battletime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
251 |
|
|
|
254 |
winner_modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_ids[st.session_state.pointer[prompt_id][winner]]]['modelVersion_id'].values[0]
|
255 |
loser_modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_ids[st.session_state.pointer[prompt_id][loser]]]['modelVersion_id'].values[0]
|
256 |
|
257 |
+
# # remove the old battle result if exists
|
258 |
+
# query = "DELETE FROM battle_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND winner = %s AND loser = %s"
|
259 |
+
# curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, winner_modelVersion_id, loser_modelVersion_id))
|
260 |
+
# curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, loser_modelVersion_id, winner_modelVersion_id))
|
261 |
|
262 |
# insert the battle result into the database
|
263 |
+
query = "INSERT INTO battle_results (username, timestamp, tag, prompt_id, winner, loser, battletime, epoch) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)"
|
264 |
+
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, winner_modelVersion_id, loser_modelVersion_id, battletime, st.session_state.epoch['ranking'][prompt_id]))
|
265 |
|
266 |
curser.close()
|
267 |
RANKING_CONN.commit()
|
268 |
|
269 |
if curr_position == total_num - 1:
|
270 |
+
st.session_state.epoch['ranking'][prompt_id] += 1
|
271 |
+
|
272 |
+
st.session_state.epoch['summary'][tag] = st.session_state.epoch['summary'].get(tag, 0) + 1
|
273 |
+
st.session_state.epoch['summary']['overall'] += 1
|
274 |
+
|
275 |
st.session_state.progress[prompt_id] = 'finished'
|
276 |
|
277 |
# drop 'modelVersion_standings' from session state if exists
|
|
|
281 |
else:
|
282 |
st.session_state.pointer[prompt_id][loser] = curr_position + 1
|
283 |
|
284 |
+
def battle_mode(self, tag, items, prompt_id):
|
285 |
+
self.battle_images(tag, items, prompt_id)
|
286 |
|
287 |
def app(self):
|
288 |
st.write('### Generative Model Ranking')
|
|
|
303 |
'tag'] in prompt_tags else 0
|
304 |
print(tag_idx)
|
305 |
|
306 |
+
# color the finished tags
|
307 |
+
finished_tags = []
|
308 |
+
for tag in prompt_tags:
|
309 |
+
append_tag = True
|
310 |
+
for prompt_id in self.promptBook[self.promptBook['tag'] == tag]['prompt_id'].unique():
|
311 |
+
if prompt_id not in st.session_state.progress or st.session_state.progress[prompt_id] != 'finished':
|
312 |
+
append_tag = False
|
313 |
+
break
|
314 |
+
if append_tag:
|
315 |
+
finished_tags.append(tag)
|
316 |
+
tag_tobe_colored = self.promptBook[self.promptBook['tag'].isin(finished_tags)]['tag'].unique().tolist()
|
317 |
+
colored_tags = self.text_coloring_add(tag_tobe_colored, prompt_tags, color_name='orange')
|
318 |
+
|
319 |
+
tag = st.radio('Select a tag', colored_tags, index=tag_idx, horizontal=True, label_visibility='collapsed')
|
320 |
+
tag = self.text_coloring_remove(tag)
|
321 |
+
print(tag)
|
322 |
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
323 |
+
# pick out prompts such that st.session_state.progress[prompt_id] == 'finished'
|
324 |
+
|
325 |
prompts = np.sort(items['prompt'].unique())[::-1].tolist()
|
326 |
|
327 |
prompt_idx = prompts.index(st.session_state.gallery_focus['prompt']) if st.session_state.gallery_focus[
|
328 |
'prompt'] in prompts else 0
|
329 |
print(prompt_idx)
|
330 |
|
331 |
+
# color the finished prompts
|
332 |
+
finished_prompts = []
|
333 |
+
for prompt_id in items['prompt_id'].unique():
|
334 |
+
if prompt_id in st.session_state.progress and st.session_state.progress[prompt_id] == 'finished':
|
335 |
+
finished_prompts.append(prompt_id)
|
336 |
+
prompt_tobe_colored = items[items['prompt_id'].isin(finished_prompts)]['prompt'].unique().tolist()
|
337 |
+
colored_prompts = self.text_coloring_add(prompt_tobe_colored, prompts, color_name='✅')
|
338 |
+
|
339 |
+
selected_prompt = st.selectbox('Select a prompt', colored_prompts, index=prompt_idx, label_visibility='collapsed')
|
340 |
+
selected_prompt = self.text_coloring_remove(selected_prompt)
|
341 |
+
|
342 |
+
st.session_state.gallery_focus = {'tag': tag, 'prompt': selected_prompt}
|
343 |
|
344 |
# mode = st.radio('Select a mode', ['Drag and Sort', 'Battle'], index=1)
|
345 |
mode = st.session_state.assigned_rank_mode
|
|
|
360 |
st.session_state.progress[prompt_id] = 'ranking'
|
361 |
|
362 |
if st.session_state.progress[prompt_id] == 'ranking':
|
363 |
+
st.session_state.epoch['ranking'][prompt_id] = st.session_state.epoch['ranking'].get(prompt_id, 1)
|
364 |
st.caption("We might pair some other images that you haven't selected based on our evaluation matrix.")
|
365 |
if mode == 'Drag and Sort':
|
366 |
+
self.dragsort_mode(tag, items, prompt_id, batch_num)
|
367 |
elif mode == 'Battle':
|
368 |
+
self.battle_mode(tag, items, prompt_id)
|
369 |
|
370 |
elif st.session_state.progress[prompt_id] == 'finished':
|
371 |
+
print(st.session_state.gallery_focus)
|
372 |
# st.toast('**Summary is available now!**')
|
373 |
# st.write('---')
|
374 |
with st.form(key='ranking_finished'):
|
|
|
393 |
st.form_submit_button('👆 Rank other prompts', use_container_width=True)
|
394 |
|
395 |
with options_panel[3]:
|
396 |
+
restart_btn = st.form_submit_button('🎖️ Re-rank this prompt', use_container_width=True, on_click=self.rerank, kwargs={'prompt_id': prompt_id})
|
397 |
+
|
398 |
+
# with st.sidebar:
|
399 |
+
# st.write('epoch: ', st.session_state.epoch['ranking'][prompt_id])
|
400 |
+
def text_coloring_add(self, tobe_colored:list, total_items, color_name='orange'):
|
401 |
+
if color_name in ['orange', 'red', 'green', 'blue', 'violet', 'yellow']:
|
402 |
+
colored = [f':{color_name}[{item}]' if item in tobe_colored else item for item in total_items]
|
403 |
+
else:
|
404 |
+
colored = [f'[{color_name}] {item}' if item in tobe_colored else item for item in total_items]
|
405 |
+
return colored
|
406 |
+
|
407 |
+
def text_coloring_remove(self, tobe_removed):
|
408 |
+
if isinstance(tobe_removed, str):
|
409 |
+
if tobe_removed.startswith(':'):
|
410 |
+
tobe_removed = tobe_removed.split('[')[-1][:-1]
|
411 |
+
|
412 |
+
elif tobe_removed.startswith('['):
|
413 |
+
tobe_removed = tobe_removed.split(']')[-1][1:]
|
414 |
+
return tobe_removed
|
415 |
+
|
416 |
+
def rerank(self, prompt_id):
|
417 |
+
st.session_state.progress[prompt_id] = 'ranking'
|
418 |
+
if st.session_state.assigned_rank_mode == 'Drag and Sort':
|
419 |
+
st.session_state.counter[prompt_id] = 0
|
420 |
+
elif st.session_state.assigned_rank_mode == 'Battle':
|
421 |
+
st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
|
422 |
+
|
423 |
+
# def connect_to_db():
|
424 |
+
# conn = pymysql.connect(
|
425 |
+
# host=os.environ.get('RANKING_DB_HOST'),
|
426 |
+
# port=3306,
|
427 |
+
# database='myRanking',
|
428 |
+
# user=os.environ.get('RANKING_DB_USER'),
|
429 |
+
# password=os.environ.get('RANKING_DB_PASSWORD'),
|
430 |
+
# charset='utf8mb4',
|
431 |
+
# cursorclass=pymysql.cursors.DictCursor
|
432 |
+
# )
|
433 |
+
#
|
434 |
+
# return conn
|
435 |
|
436 |
|
437 |
if __name__ == "__main__":
|
pages/Summary.py
CHANGED
@@ -15,7 +15,7 @@ from streamlit_extras.stylable_container import stylable_container
|
|
15 |
from st_clickable_images import clickable_images
|
16 |
|
17 |
from pages.Gallery import load_hf_dataset
|
18 |
-
from
|
19 |
|
20 |
|
21 |
class DashboardApp:
|
@@ -42,35 +42,52 @@ class DashboardApp:
|
|
42 |
if back_to_ranking:
|
43 |
switch_page('ranking')
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
|
52 |
|
53 |
def leaderboard(self, tag, db_table):
|
54 |
tag = '%' if tag == 'overview' else tag
|
55 |
|
56 |
-
#
|
|
|
|
|
57 |
curser = RANKING_CONN.cursor()
|
58 |
-
curser.execute(f"SELECT * FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag LIKE '{tag}'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
results = curser.fetchall()
|
60 |
curser.close()
|
61 |
|
|
|
|
|
62 |
if tag not in st.session_state.modelVersion_standings:
|
63 |
st.session_state.modelVersion_standings[tag] = self.score_calculator(results, db_table)
|
64 |
|
65 |
# sort the modelVersion_standings by value into a list of tuples in descending order
|
66 |
st.session_state.modelVersion_standings[tag] = sorted(st.session_state.modelVersion_standings[tag].items(), key=lambda x: x[1], reverse=True)
|
67 |
-
|
68 |
-
# tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
|
69 |
-
|
70 |
-
# with tab1:
|
71 |
-
# self.podium(modelVersion_standings)
|
72 |
-
# switch_stage = st.toggle('Manual Reorder', key='switch_stage')
|
73 |
-
|
74 |
example_prompts = []
|
75 |
# get example images
|
76 |
for key, value in st.session_state.selected_dict.items():
|
@@ -78,19 +95,8 @@ class DashboardApp:
|
|
78 |
if model[0] in value:
|
79 |
example_prompts.append(key)
|
80 |
|
81 |
-
# if switch_stage:
|
82 |
-
# self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit', example_prompts=example_prompts)
|
83 |
-
# else:
|
84 |
self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='display', example_prompts=example_prompts)
|
85 |
-
# if st.session_state.summary_mode == 'display':
|
86 |
-
# switch_stage = st.button('Manual Reorder', key='switch_stage_edit', on_click=lambda: st.session_state.__setitem__('summary_mode', 'edit'))
|
87 |
-
# self.podium_expander(tag, n=3, summary_mode='display')
|
88 |
-
#
|
89 |
-
# elif st.session_state.summary_mode == 'edit':
|
90 |
-
# switch_stage = st.button('Done', key='switch_stage_done', type='primary', on_click=lambda: st.session_state.__setitem__('summary_mode', 'display'))
|
91 |
-
# self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit')
|
92 |
|
93 |
-
# with tab2:
|
94 |
st.write('---')
|
95 |
st.write('**Detailed information of all selected models**')
|
96 |
detailed_info = pd.merge(pd.DataFrame(st.session_state.modelVersion_standings[tag], columns=['modelVersion_id', 'ranking_score']), self.roster, on='modelVersion_id')
|
@@ -101,6 +107,7 @@ class DashboardApp:
|
|
101 |
st.caption('You can click the header to sort the table by that column.')
|
102 |
|
103 |
def podium_expander(self, tag, example_prompts, n=3, summary_mode: ['display', 'edit'] = 'display'):
|
|
|
104 |
|
105 |
for i in range(n):
|
106 |
modelVersion_id = st.session_state.modelVersion_standings[tag][i][0]
|
@@ -137,7 +144,7 @@ class DashboardApp:
|
|
137 |
example_images = [f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{image}.png" for image in example_images]
|
138 |
clickable_images(
|
139 |
example_images,
|
140 |
-
img_style={"margin": "5px", "height": "120px"}
|
141 |
)
|
142 |
|
143 |
else:
|
@@ -173,19 +180,33 @@ class DashboardApp:
|
|
173 |
if i != n - 1:
|
174 |
st.write('---')
|
175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
def switch_order(self, tag, current, target):
|
177 |
-
# st.session_state.modelVersion_standings[tag][current], st.session_state.modelVersion_standings[tag][target] = st.session_state.modelVersion_standings[tag][target], st.session_state.modelVersion_standings[tag][current]
|
178 |
# insert the current before the target
|
179 |
st.session_state.modelVersion_standings[tag].insert(target, st.session_state.modelVersion_standings[tag].pop(current))
|
180 |
-
|
181 |
-
|
182 |
-
# RANKING_CONN.commit()
|
183 |
-
# curser.close()
|
184 |
curser = RANKING_CONN.cursor()
|
185 |
# clear the current user's ranking results
|
186 |
-
curser.execute(f"DELETE FROM summary_results WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag = '{
|
187 |
for i in range(len(st.session_state.modelVersion_standings[tag])):
|
188 |
-
curser.execute(f"INSERT INTO summary_results (username, timestamp, tag, modelVersion_id, position, ranking_score) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{
|
189 |
RANKING_CONN.commit()
|
190 |
curser.close()
|
191 |
|
@@ -197,7 +218,6 @@ class DashboardApp:
|
|
197 |
|
198 |
for record in results:
|
199 |
modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
|
200 |
-
|
201 |
# add the loser who never wins
|
202 |
if record['loser'] not in modelVersion_standings:
|
203 |
modelVersion_standings[record['loser']] = 0
|
@@ -221,7 +241,7 @@ class DashboardApp:
|
|
221 |
# get tags from database of the current user
|
222 |
db_table = 'sort_results' if mode == 'Drag and Sort' else 'battle_results'
|
223 |
|
224 |
-
tags = [
|
225 |
curser = RANKING_CONN.cursor()
|
226 |
curser.execute(
|
227 |
f"SELECT DISTINCT tag FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}'")
|
@@ -229,29 +249,30 @@ class DashboardApp:
|
|
229 |
tags.append(row['tag'])
|
230 |
curser.close()
|
231 |
|
232 |
-
if tags ==
|
233 |
st.info(f'No rankings are finished with {mode} mode yet.')
|
234 |
|
235 |
else:
|
236 |
-
tags = tags[1:2] if len(tags) == 2 else tags
|
|
|
237 |
tag = st.radio('Select a tag', tags, index=0, horizontal=True, label_visibility='collapsed')
|
238 |
self.sidebar(tags, mode)
|
239 |
self.leaderboard(tag, db_table)
|
240 |
|
241 |
-
with st.sidebar:
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
|
256 |
|
257 |
if __name__ == "__main__":
|
|
|
15 |
from st_clickable_images import clickable_images
|
16 |
|
17 |
from pages.Gallery import load_hf_dataset
|
18 |
+
from Home import connect_to_db
|
19 |
|
20 |
|
21 |
class DashboardApp:
|
|
|
42 |
if back_to_ranking:
|
43 |
switch_page('ranking')
|
44 |
|
45 |
+
with st.form('overall_feedback'):
|
46 |
+
comment = st.text_area('Please leave your comments here.', key='comment')
|
47 |
+
submit_feedback = st.form_submit_button('Submit Feedback')
|
48 |
+
if submit_feedback:
|
49 |
+
commenttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
50 |
+
curser = RANKING_CONN.cursor()
|
51 |
+
# parse the comment to at most 300 to avoid SQL injection
|
52 |
+
for i in range(0, len(comment), 300):
|
53 |
+
curser.execute(f"INSERT INTO comments (username, timestamp, comment, commenttime) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{comment[i:i+300]}', '{commenttime}')")
|
54 |
+
RANKING_CONN.commit()
|
55 |
+
curser.close()
|
56 |
|
57 |
+
st.sidebar.info('🙏 **Thanks for your feedback! We will take it into consideration in our future work.**')
|
58 |
|
59 |
def leaderboard(self, tag, db_table):
|
60 |
tag = '%' if tag == 'overview' else tag
|
61 |
|
62 |
+
# print('tag', tag)
|
63 |
+
|
64 |
+
# get the ranking results of the current user with the lastest epoch
|
65 |
curser = RANKING_CONN.cursor()
|
66 |
+
# curser.execute(f"SELECT * FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag LIKE '{tag}'")
|
67 |
+
# curser.execute(f"SELECT * FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag LIKE '{tag}' ORDER BY epoch DESC LIMIT 1")
|
68 |
+
curser.execute(
|
69 |
+
f"SELECT * FROM {db_table}\
|
70 |
+
WHERE username = '{st.session_state.user_id[0]}'\
|
71 |
+
AND timestamp = '{st.session_state.user_id[1]}'\
|
72 |
+
AND tag LIKE '{tag}'\
|
73 |
+
AND epoch =\
|
74 |
+
(SELECT MAX(epoch) FROM {db_table}\
|
75 |
+
WHERE username = '{st.session_state.user_id[0]}'\
|
76 |
+
AND timestamp = '{st.session_state.user_id[1]}'\
|
77 |
+
AND tag LIKE '{tag}')")
|
78 |
+
|
79 |
+
|
80 |
results = curser.fetchall()
|
81 |
curser.close()
|
82 |
|
83 |
+
# print('results', results, len(results))
|
84 |
+
|
85 |
if tag not in st.session_state.modelVersion_standings:
|
86 |
st.session_state.modelVersion_standings[tag] = self.score_calculator(results, db_table)
|
87 |
|
88 |
# sort the modelVersion_standings by value into a list of tuples in descending order
|
89 |
st.session_state.modelVersion_standings[tag] = sorted(st.session_state.modelVersion_standings[tag].items(), key=lambda x: x[1], reverse=True)
|
90 |
+
print(st.session_state.modelVersion_standings[tag])
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
example_prompts = []
|
92 |
# get example images
|
93 |
for key, value in st.session_state.selected_dict.items():
|
|
|
95 |
if model[0] in value:
|
96 |
example_prompts.append(key)
|
97 |
|
|
|
|
|
|
|
98 |
self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='display', example_prompts=example_prompts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
|
|
100 |
st.write('---')
|
101 |
st.write('**Detailed information of all selected models**')
|
102 |
detailed_info = pd.merge(pd.DataFrame(st.session_state.modelVersion_standings[tag], columns=['modelVersion_id', 'ranking_score']), self.roster, on='modelVersion_id')
|
|
|
107 |
st.caption('You can click the header to sort the table by that column.')
|
108 |
|
109 |
def podium_expander(self, tag, example_prompts, n=3, summary_mode: ['display', 'edit'] = 'display'):
|
110 |
+
self.save_summary(tag)
|
111 |
|
112 |
for i in range(n):
|
113 |
modelVersion_id = st.session_state.modelVersion_standings[tag][i][0]
|
|
|
144 |
example_images = [f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{image}.png" for image in example_images]
|
145 |
clickable_images(
|
146 |
example_images,
|
147 |
+
img_style={"margin": "5px", "height": "120px"},
|
148 |
)
|
149 |
|
150 |
else:
|
|
|
180 |
if i != n - 1:
|
181 |
st.write('---')
|
182 |
|
183 |
+
def save_summary(self, tag):
|
184 |
+
# get the lastest summary_results epoch of the current user
|
185 |
+
tag_name = 'overview' if tag == '%' else tag
|
186 |
+
curser = RANKING_CONN.cursor()
|
187 |
+
curser.execute(f"SELECT epoch FROM summary_results WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag = '{tag_name}' ORDER BY epoch DESC LIMIT 1")
|
188 |
+
latest_epoch = curser.fetchone()
|
189 |
+
curser.close()
|
190 |
+
# print('latest_epoch',latest_epoch)
|
191 |
+
if latest_epoch is None or latest_epoch['epoch'] < st.session_state.epoch['summary'][tag_name]:
|
192 |
+
# save the current ranking results to the database
|
193 |
+
summarytime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
194 |
+
curser = RANKING_CONN.cursor()
|
195 |
+
for i in range(len(st.session_state.modelVersion_standings[tag])):
|
196 |
+
curser.execute(f"INSERT INTO summary_results (username, timestamp, tag, modelVersion_id, position, ranking_score, summarytime, epoch, customized) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{tag_name}', '{st.session_state.modelVersion_standings[tag][i][0]}', {i+1}, {st.session_state.modelVersion_standings[tag][i][1]}, '{summarytime}', {st.session_state.epoch['summary'][tag_name]}, 0)")
|
197 |
+
RANKING_CONN.commit()
|
198 |
+
curser.close()
|
199 |
+
|
200 |
def switch_order(self, tag, current, target):
|
|
|
201 |
# insert the current before the target
|
202 |
st.session_state.modelVersion_standings[tag].insert(target, st.session_state.modelVersion_standings[tag].pop(current))
|
203 |
+
tag_name = 'overview' if tag == '%' else tag
|
204 |
+
summarytime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
|
205 |
curser = RANKING_CONN.cursor()
|
206 |
# clear the current user's ranking results
|
207 |
+
curser.execute(f"DELETE FROM summary_results WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag = '{tag_name}' AND epoch = {st.session_state.epoch['summary'][tag_name]}")
|
208 |
for i in range(len(st.session_state.modelVersion_standings[tag])):
|
209 |
+
curser.execute(f"INSERT INTO summary_results (username, timestamp, tag, modelVersion_id, position, ranking_score, summarytime, epoch, customized) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{tag_name}', '{st.session_state.modelVersion_standings[tag][i][0]}', {i+1}, {st.session_state.modelVersion_standings[tag][i][1]}, '{summarytime}', {st.session_state.epoch['summary'][tag_name]}, 1)")
|
210 |
RANKING_CONN.commit()
|
211 |
curser.close()
|
212 |
|
|
|
218 |
|
219 |
for record in results:
|
220 |
modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
|
|
|
221 |
# add the loser who never wins
|
222 |
if record['loser'] not in modelVersion_standings:
|
223 |
modelVersion_standings[record['loser']] = 0
|
|
|
241 |
# get tags from database of the current user
|
242 |
db_table = 'sort_results' if mode == 'Drag and Sort' else 'battle_results'
|
243 |
|
244 |
+
tags = []
|
245 |
curser = RANKING_CONN.cursor()
|
246 |
curser.execute(
|
247 |
f"SELECT DISTINCT tag FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}'")
|
|
|
249 |
tags.append(row['tag'])
|
250 |
curser.close()
|
251 |
|
252 |
+
if len(tags) == 0:
|
253 |
st.info(f'No rankings are finished with {mode} mode yet.')
|
254 |
|
255 |
else:
|
256 |
+
# tags = tags[1:2] if len(tags) == 2 else tags
|
257 |
+
tag = ['overview'] + tags if len(tags) > 1 else tags
|
258 |
tag = st.radio('Select a tag', tags, index=0, horizontal=True, label_visibility='collapsed')
|
259 |
self.sidebar(tags, mode)
|
260 |
self.leaderboard(tag, db_table)
|
261 |
|
262 |
+
# with st.sidebar:
|
263 |
+
# with st.form('overall_feedback'):
|
264 |
+
# comment = st.text_area('Please leave your comments here.', key='comment')
|
265 |
+
# submit_feedback = st.form_submit_button('Submit Feedback')
|
266 |
+
# if submit_feedback:
|
267 |
+
# commenttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
268 |
+
# curser = RANKING_CONN.cursor()
|
269 |
+
# # parse the comment to at most 300 to avoid SQL injection
|
270 |
+
# for i in range(0, len(comment), 300):
|
271 |
+
# curser.execute(f"INSERT INTO comments (username, timestamp, comment, commenttime) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{comment[i:i+300]}', '{commenttime}')")
|
272 |
+
# RANKING_CONN.commit()
|
273 |
+
# curser.close()
|
274 |
+
#
|
275 |
+
# st.sidebar.info('🙏 **Thanks for your feedback! We will take it into consideration in our future work.**')
|
276 |
|
277 |
|
278 |
if __name__ == "__main__":
|