Ricercar commited on
Commit
bca72f3
·
1 Parent(s): dba1106

fully connected to database

Browse files

add a counter called 'epoch'
add checked indicator in ranking page
better button logic

Files changed (4) hide show
  1. Home.py +27 -5
  2. pages/Gallery.py +72 -32
  3. pages/Ranking.py +120 -55
  4. pages/Summary.py +74 -53
Home.py CHANGED
@@ -1,7 +1,10 @@
1
- from datetime import datetime
2
- import streamlit as st
3
  import random
4
- import time
 
 
 
 
5
  from streamlit_extras.switch_page_button import switch_page
6
 
7
 
@@ -31,27 +34,32 @@ def login():
31
 
32
 
33
  def save_user_id(user_id):
 
34
  print(user_id)
35
  if not user_id:
36
  user_id = 'anonymous' + str(random.randint(0, 100000))
37
  st.session_state.user_id = [user_id, datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
38
  st.session_state.assigned_rank_mode = random.choice(['Drag and Sort', 'Battle'])
 
39
 
40
 
41
  def logout():
42
  st.session_state.pop('user_id', None)
43
  st.session_state.pop('selected_dict', None)
 
44
  st.session_state.pop('score_weights', None)
45
  st.session_state.pop('gallery_state', None)
46
  st.session_state.pop('edit_state', None)
47
  st.session_state.pop('progress', None)
 
 
48
  st.session_state.pop('gallery_focus', None)
49
  st.session_state.pop('assigned_rank_mode', None)
50
  st.session_state.pop('show_NSFW', None)
51
  st.session_state.pop('modelVersion_standings', None)
52
 
53
 
54
- def info():
55
  with st.sidebar:
56
  st.write('## About')
57
  st.write(
@@ -63,10 +71,24 @@ def info():
63
  )
64
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  if __name__ == '__main__':
67
  # print(st.source_util.get_pages('Home.py'))
68
  st.set_page_config(page_title="Login", page_icon="🏠", layout="wide")
69
- info()
70
  st.write('A Research by [MAPS Lab](https://whongyi.github.io/MAPS-research), [NYU Shanghai](https://shanghai.nyu.edu)')
71
  st.title("🙌 Welcome to GEMRec Gallery!")
72
 
 
1
+ import os
 
2
  import random
3
+
4
+ import pymysql.cursors
5
+ import streamlit as st
6
+
7
+ from datetime import datetime
8
  from streamlit_extras.switch_page_button import switch_page
9
 
10
 
 
34
 
35
 
36
  def save_user_id(user_id):
37
+ user_id = user_id[:60]
38
  print(user_id)
39
  if not user_id:
40
  user_id = 'anonymous' + str(random.randint(0, 100000))
41
  st.session_state.user_id = [user_id, datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
42
  st.session_state.assigned_rank_mode = random.choice(['Drag and Sort', 'Battle'])
43
+ st.session_state.epoch = {'gallery': 0, 'ranking': {}, 'summary': {'overall': 0}}
44
 
45
 
46
  def logout():
47
  st.session_state.pop('user_id', None)
48
  st.session_state.pop('selected_dict', None)
49
+ st.session_state.pop('epoch', None)
50
  st.session_state.pop('score_weights', None)
51
  st.session_state.pop('gallery_state', None)
52
  st.session_state.pop('edit_state', None)
53
  st.session_state.pop('progress', None)
54
+ st.session_state.pop('pointer', None)
55
+ st.session_state.pop('counter', None)
56
  st.session_state.pop('gallery_focus', None)
57
  st.session_state.pop('assigned_rank_mode', None)
58
  st.session_state.pop('show_NSFW', None)
59
  st.session_state.pop('modelVersion_standings', None)
60
 
61
 
62
+ def project_info():
63
  with st.sidebar:
64
  st.write('## About')
65
  st.write(
 
71
  )
72
 
73
 
74
+ def connect_to_db():
75
+ conn = pymysql.connect(
76
+ host=os.environ.get('RANKING_DB_HOST'),
77
+ port=3306,
78
+ database='myRanking',
79
+ user=os.environ.get('RANKING_DB_USER'),
80
+ password=os.environ.get('RANKING_DB_PASSWORD'),
81
+ charset='utf8mb4',
82
+ cursorclass=pymysql.cursors.DictCursor
83
+ )
84
+
85
+ return conn
86
+
87
+
88
  if __name__ == '__main__':
89
  # print(st.source_util.get_pages('Home.py'))
90
  st.set_page_config(page_title="Login", page_icon="🏠", layout="wide")
91
+ project_info()
92
  st.write('A Research by [MAPS Lab](https://whongyi.github.io/MAPS-research), [NYU Shanghai](https://shanghai.nyu.edu)')
93
  st.title("🙌 Welcome to GEMRec Gallery!")
94
 
pages/Gallery.py CHANGED
@@ -12,6 +12,7 @@ import streamlit.components.v1 as components
12
 
13
  from bs4 import BeautifulSoup
14
  from datasets import load_dataset, Dataset, load_from_disk
 
15
  from huggingface_hub import login
16
  from streamlit_agraph import agraph, Node, Edge, Config
17
  from streamlit_extras.switch_page_button import switch_page
@@ -19,6 +20,7 @@ from streamlit_extras.tags import tagger_component
19
  from streamlit_extras.no_default_selectbox import selectbox
20
  from sklearn.svm import LinearSVC
21
 
 
22
 
23
  class GalleryApp:
24
  def __init__(self, promptBook, images_ds):
@@ -203,7 +205,7 @@ class GalleryApp:
203
 
204
  # remove coloring from tag
205
  tag = self.text_coloring_remove(tag)
206
- print('tag: ', tag)
207
 
208
  # print('current state: ', st.session_state.gallery_state)
209
 
@@ -237,7 +239,7 @@ class GalleryApp:
237
 
238
  # remove coloring from prompt
239
  selected_prompt = self.text_coloring_remove(selected_prompt)
240
- print('selected_prompt: ', selected_prompt)
241
  st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
242
 
243
  if selected_prompt is None:
@@ -282,15 +284,15 @@ class GalleryApp:
282
  pass
283
 
284
  if has_selection:
285
- checkout = st.button('Check out selections ➡️', use_container_width=True, type='primary')
286
- if checkout:
287
- # add focus to session state
288
- st.session_state.gallery_focus['tag'] = tag
289
- st.session_state.gallery_focus['prompt'] = selected_prompt
290
-
291
- st.session_state.gallery_state = 'check out'
292
- # print(st.session_state.gallery_state)
293
- st.experimental_rerun()
294
  else:
295
  st.button(':orange[👇 **Select images you like below**]', disabled=True, use_container_width=True)
296
  try:
@@ -308,6 +310,12 @@ class GalleryApp:
308
 
309
  self.checkout_mode(tag, items)
310
 
 
 
 
 
 
 
311
 
312
  def random_gallery_focus(self, tags):
313
  st.session_state.gallery_focus['tag'] = random.choice(tags)
@@ -315,7 +323,6 @@ class GalleryApp:
315
  prompts = self.promptBook[self.promptBook['tag'] == st.session_state.gallery_focus['tag']]['prompt'].unique()
316
  st.session_state.gallery_focus['prompt'] = random.choice(prompts)
317
 
318
-
319
  def graph_mode(self, prompt_id, items):
320
  graph_cols = st.columns([3, 1])
321
  # prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
@@ -353,24 +360,11 @@ class GalleryApp:
353
 
354
  if checked:
355
  # deselect = st.button('Deselect', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True)
356
- deselect = st.form_submit_button('Deselect', use_container_width=True)
357
- if deselect:
358
- st.session_state.selected_dict[item['prompt_id']].remove(item['modelVersion_id'])
359
- self.remove_ranking_states(item['prompt_id'])
360
- st.experimental_rerun()
361
 
362
  else:
363
  # select = st.button('Select', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True, type='primary')
364
- select = st.form_submit_button('Select', use_container_width=True, type='primary')
365
- if select:
366
- st.session_state.selected_dict[item['prompt_id']].append(item['modelVersion_id'])
367
- self.remove_ranking_states(item['prompt_id'])
368
-
369
- # add focus to session state
370
- st.session_state.gallery_focus['tag'] = item['tag']
371
- st.session_state.gallery_focus['prompt'] = item['prompt']
372
-
373
- st.experimental_rerun()
374
 
375
  # st.write(item)
376
  infos = ['model_name', 'modelVersion_name', 'model_download_count', 'clip_score', 'mcos_score',
@@ -384,6 +378,17 @@ class GalleryApp:
384
  else:
385
  st.info('Please click on an image to show')
386
 
 
 
 
 
 
 
 
 
 
 
 
387
  def checkout_mode(self, tag, items):
388
  # st.write(items)
389
  if len(items) > 0:
@@ -407,7 +412,7 @@ class GalleryApp:
407
  ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
408
  'total_score', 'model_download_count', 'clip_score', 'mcos_score',
409
  'norm_nsfw'],
410
- label_visibility='collapsed', key=f'info_{prompt_id}', placeholder='Select what infos to show')
411
 
412
  with checkout_panel[-1]:
413
  checkout_buttons = st.columns([1, 1, 1])
@@ -418,7 +423,7 @@ class GalleryApp:
418
  st.session_state.gallery_focus['prompt'] = prompt
419
  print(st.session_state.gallery_focus)
420
  st.session_state.gallery_state = 'graph'
421
- st.experimental_rerun()
422
 
423
  with checkout_buttons[1]:
424
  # init edit state
@@ -429,7 +434,7 @@ class GalleryApp:
429
  edit = st.button('Edit', key=f'checkout_edit_{prompt_id}', use_container_width=True)
430
  if edit:
431
  st.session_state.edit_state = True
432
- st.experimental_rerun()
433
  else:
434
  done = st.button('Done', key=f'checkout_done_{prompt_id}', use_container_width=True)
435
  if done:
@@ -443,7 +448,7 @@ class GalleryApp:
443
  st.session_state.selected_dict[prompt_id].append(int(keys[2]))
444
 
445
  st.session_state.edit_state = False
446
- st.experimental_rerun()
447
 
448
  with checkout_buttons[-1]:
449
  proceed = st.button('Proceed ➡️', key=f'checkout_proceed_{prompt_id}', use_container_width=True,
@@ -452,6 +457,40 @@ class GalleryApp:
452
  st.session_state.gallery_focus['tag'] = tag
453
  st.session_state.gallery_focus['prompt'] = prompt
454
  st.session_state.gallery_state = 'graph'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  switch_page('ranking')
456
 
457
  self.gallery_standard(items[items['prompt_id'] == prompt_id].reset_index(drop=True), 4, info, show_checkbox=st.session_state.edit_state)
@@ -463,7 +502,7 @@ class GalleryApp:
463
  st.session_state.gallery_focus['tag'] = tag
464
  st.session_state.gallery_focus['prompt'] = None
465
  st.session_state.gallery_state = 'graph'
466
- st.experimental_rerun()
467
 
468
  def remove_ranking_states(self, prompt_id):
469
  # for drag sort
@@ -547,6 +586,7 @@ if __name__ == "__main__":
547
  if home_btn:
548
  switch_page("home")
549
  else:
 
550
  roster, promptBook, images_ds = load_hf_dataset(st.session_state.show_NSFW)
551
 
552
  app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
 
12
 
13
  from bs4 import BeautifulSoup
14
  from datasets import load_dataset, Dataset, load_from_disk
15
+ from datetime import datetime
16
  from huggingface_hub import login
17
  from streamlit_agraph import agraph, Node, Edge, Config
18
  from streamlit_extras.switch_page_button import switch_page
 
20
  from streamlit_extras.no_default_selectbox import selectbox
21
  from sklearn.svm import LinearSVC
22
 
23
+ from Home import connect_to_db
24
 
25
  class GalleryApp:
26
  def __init__(self, promptBook, images_ds):
 
205
 
206
  # remove coloring from tag
207
  tag = self.text_coloring_remove(tag)
208
+ # print('tag: ', tag)
209
 
210
  # print('current state: ', st.session_state.gallery_state)
211
 
 
239
 
240
  # remove coloring from prompt
241
  selected_prompt = self.text_coloring_remove(selected_prompt)
242
+ # print('selected_prompt: ', selected_prompt)
243
  st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
244
 
245
  if selected_prompt is None:
 
284
  pass
285
 
286
  if has_selection:
287
+ checkout = st.button('Check out selections ➡️', use_container_width=True, type='primary', on_click=self.switch_to_checkout, args=(tag, selected_prompt))
288
+ # if checkout:
289
+ # # add focus to session state
290
+ # st.session_state.gallery_focus['tag'] = tag
291
+ # st.session_state.gallery_focus['prompt'] = selected_prompt
292
+ #
293
+ # st.session_state.gallery_state = 'check out'
294
+ # # print(st.session_state.gallery_state)
295
+ # st.rerun()
296
  else:
297
  st.button(':orange[👇 **Select images you like below**]', disabled=True, use_container_width=True)
298
  try:
 
310
 
311
  self.checkout_mode(tag, items)
312
 
313
+ def switch_to_checkout(self, tag, selected_prompt):
314
+ # add focus to session state
315
+ st.session_state.gallery_focus['tag'] = tag
316
+ st.session_state.gallery_focus['prompt'] = selected_prompt
317
+
318
+ st.session_state.gallery_state = 'check out'
319
 
320
  def random_gallery_focus(self, tags):
321
  st.session_state.gallery_focus['tag'] = random.choice(tags)
 
323
  prompts = self.promptBook[self.promptBook['tag'] == st.session_state.gallery_focus['tag']]['prompt'].unique()
324
  st.session_state.gallery_focus['prompt'] = random.choice(prompts)
325
 
 
326
  def graph_mode(self, prompt_id, items):
327
  graph_cols = st.columns([3, 1])
328
  # prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
 
360
 
361
  if checked:
362
  # deselect = st.button('Deselect', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True)
363
+ deselect = st.form_submit_button('Deselect', use_container_width=True, on_click=self.image_selection_control, args=(item['tag'], item['prompt'], item['prompt_id'], item['modelVersion_id'], False))
 
 
 
 
364
 
365
  else:
366
  # select = st.button('Select', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True, type='primary')
367
+ select = st.form_submit_button('Select', use_container_width=True, type='primary', on_click=self.image_selection_control, args=(item['tag'], item['prompt'], item['prompt_id'], item['modelVersion_id'], True))
 
 
 
 
 
 
 
 
 
368
 
369
  # st.write(item)
370
  infos = ['model_name', 'modelVersion_name', 'model_download_count', 'clip_score', 'mcos_score',
 
378
  else:
379
  st.info('Please click on an image to show')
380
 
381
+ def image_selection_control(self, tag, prompt, prompt_id, modelVersion_id, to_select):
382
+ self.remove_ranking_states(prompt_id)
383
+ if to_select:
384
+ st.session_state.selected_dict[prompt_id].append(modelVersion_id)
385
+ # add focus to session state
386
+ st.session_state.gallery_focus['tag'] = tag
387
+ st.session_state.gallery_focus['prompt'] = prompt
388
+
389
+ else:
390
+ st.session_state.selected_dict[prompt_id].remove(modelVersion_id)
391
+
392
  def checkout_mode(self, tag, items):
393
  # st.write(items)
394
  if len(items) > 0:
 
412
  ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
413
  'total_score', 'model_download_count', 'clip_score', 'mcos_score',
414
  'norm_nsfw'],
415
+ label_visibility='collapsed', key=f'info_{prompt_id}', placeholder='Select what info to show')
416
 
417
  with checkout_panel[-1]:
418
  checkout_buttons = st.columns([1, 1, 1])
 
423
  st.session_state.gallery_focus['prompt'] = prompt
424
  print(st.session_state.gallery_focus)
425
  st.session_state.gallery_state = 'graph'
426
+ st.rerun()
427
 
428
  with checkout_buttons[1]:
429
  # init edit state
 
434
  edit = st.button('Edit', key=f'checkout_edit_{prompt_id}', use_container_width=True)
435
  if edit:
436
  st.session_state.edit_state = True
437
+ st.rerun()
438
  else:
439
  done = st.button('Done', key=f'checkout_done_{prompt_id}', use_container_width=True)
440
  if done:
 
448
  st.session_state.selected_dict[prompt_id].append(int(keys[2]))
449
 
450
  st.session_state.edit_state = False
451
+ st.rerun()
452
 
453
  with checkout_buttons[-1]:
454
  proceed = st.button('Proceed ➡️', key=f'checkout_proceed_{prompt_id}', use_container_width=True,
 
457
  st.session_state.gallery_focus['tag'] = tag
458
  st.session_state.gallery_focus['prompt'] = prompt
459
  st.session_state.gallery_state = 'graph'
460
+
461
+ print('selected_dict: ', st.session_state.selected_dict)
462
+
463
+ # # save the user selection to database
464
+ cursor = GALLERY_CONN.cursor()
465
+ st.session_state.epoch['gallery'] += 1
466
+ checkouttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
467
+ # for modelVersion_id in st.session_state.selected_dict[prompt_id]:
468
+ for key, values in st.session_state.selected_dict.items():
469
+ # print('key: ', key, 'values: ', values)
470
+ key_tag = self.promptBook[self.promptBook['prompt_id'] == key]['tag'].unique()[0]
471
+ for value in values:
472
+ query = "INSERT INTO gallery_selections (username, timestamp, tag, prompt_id, modelVersion_id, checkouttime, epoch) VALUES ('{}', '{}', '{}', '{}', {}, '{}', {})".format(st.session_state.user_id[0], st.session_state.user_id[1], key_tag, key, value, checkouttime, st.session_state.epoch['gallery'])
473
+ print(query)
474
+ cursor.execute(query)
475
+ GALLERY_CONN.commit()
476
+ cursor.close()
477
+
478
+ # get the largest epoch number of this user and prompt
479
+ cursor = GALLERY_CONN.cursor()
480
+ db_table = 'battle_results' if st.session_state.assigned_rank_mode=='Battle' else 'sort_results'
481
+ query = "SELECT MAX(epoch) FROM {} WHERE username = '{}' AND timestamp = '{}' AND prompt_id = {}".format(db_table, st.session_state.user_id[0], st.session_state.user_id[1], prompt_id)
482
+ cursor.execute(query)
483
+ max_epoch = cursor.fetchone()['MAX(epoch)'],
484
+ # print('max epoch: ', max_epoch, type(max_epoch))
485
+ cursor.close()
486
+
487
+ try:
488
+ st.session_state.epoch['ranking'][prompt_id] = max_epoch[0] + 1
489
+ except TypeError:
490
+ st.session_state.epoch['ranking'][prompt_id] = 1
491
+ # st.session_state.epoch['summary'][tag] = st.session_state.epoch['summary'].get(tag, 0) + 1
492
+ # st.session_state.epoch['summary']['overall'] += 1
493
+ print('epoch: ', st.session_state.epoch)
494
  switch_page('ranking')
495
 
496
  self.gallery_standard(items[items['prompt_id'] == prompt_id].reset_index(drop=True), 4, info, show_checkbox=st.session_state.edit_state)
 
502
  st.session_state.gallery_focus['tag'] = tag
503
  st.session_state.gallery_focus['prompt'] = None
504
  st.session_state.gallery_state = 'graph'
505
+ st.rerun()
506
 
507
  def remove_ranking_states(self, prompt_id):
508
  # for drag sort
 
586
  if home_btn:
587
  switch_page("home")
588
  else:
589
+ GALLERY_CONN = connect_to_db()
590
  roster, promptBook, images_ds = load_hf_dataset(st.session_state.show_NSFW)
591
 
592
  app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
pages/Ranking.py CHANGED
@@ -11,6 +11,7 @@ from streamlit_elements import elements, mui, html, dashboard, nivo
11
  from streamlit_extras.switch_page_button import switch_page
12
 
13
  from pages.Gallery import load_hf_dataset
 
14
 
15
 
16
  class RankingApp:
@@ -122,7 +123,7 @@ class RankingApp:
122
  for k in st.session_state.ranking[prompt_id][batch_idx].keys():
123
  st.session_state.ranking[prompt_id][batch_idx][k] = sorted_list.index(k)
124
 
125
- def dragsort_mode(self, items, prompt_id, batch_num):
126
  st.session_state.counter[prompt_id] = 0 if prompt_id not in st.session_state.counter else \
127
  st.session_state.counter[prompt_id]
128
 
@@ -149,16 +150,16 @@ class RankingApp:
149
  with control:
150
  if st.session_state.counter[prompt_id] < batch_num - 1:
151
  st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch',
152
- kwargs={'prompt_id': prompt_id}, use_container_width=True, type='primary')
153
  else:
154
  st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished',
155
- kwargs={'prompt_id': prompt_id, 'progress': 'finished'}, use_container_width=True, type='primary')
156
 
157
  if st.session_state.counter[prompt_id] > 0:
158
  st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch',
159
  kwargs={'prompt_id': prompt_id}, use_container_width=True)
160
 
161
- def next_batch(self, prompt_id, progress=None):
162
 
163
  curser = RANKING_CONN.cursor()
164
 
@@ -168,27 +169,36 @@ class RankingApp:
168
  modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
169
  position_version_dict[position] = modelVersion_id
170
 
171
- # get all records of this user and prompt
172
- query = "SELECT * FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s"
173
- curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id))
174
- results = curser.fetchall()
175
- print(results)
176
-
177
- # remove the old ranking with the same modelVersion_id if exists
178
- for result in results:
179
- prev_ids = [result['position1'], result['position2'], result['position3'], result['position4']]
180
- curr_ids = [position_version_dict[0], position_version_dict[1], position_version_dict[2], position_version_dict[3]]
181
- if len(set(prev_ids).intersection(set(curr_ids))) == 4:
182
- query = "DELETE FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND position1 = %s AND position2 = %s AND position3 = %s AND position4 = %s"
183
- curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, result['position1'], result['position2'], result['position3'], result['position4']))
184
-
185
- query = "INSERT INTO sort_results (username, timestamp, tag, prompt_id, position1, position2, position3, position4) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)"
186
- curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, position_version_dict[0], position_version_dict[1], position_version_dict[2], position_version_dict[3]))
 
 
 
 
187
 
188
  curser.close()
189
  RANKING_CONN.commit()
190
 
191
  if progress == 'finished':
 
 
 
 
 
192
  st.session_state.progress[prompt_id] = 'finished'
193
  # drop 'modelVersion_standings' from session state if exists
194
  st.session_state.pop('modelVersion_standings', None)
@@ -198,7 +208,7 @@ class RankingApp:
198
  def prev_batch(self, prompt_id):
199
  st.session_state.counter[prompt_id] -= 1
200
 
201
- def battle_images(self, items, prompt_id):
202
  if 'pointer' not in st.session_state:
203
  st.session_state.pointer = {}
204
 
@@ -221,7 +231,7 @@ class RankingApp:
221
  # total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
222
  # st.write(f'Total Score: {total_score}')
223
 
224
- btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
225
  st.image(img_url, use_column_width=True)
226
 
227
  with right:
@@ -232,10 +242,10 @@ class RankingApp:
232
  # total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
233
  # st.write(f'Total Score: {total_score}')
234
 
235
- btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
236
  st.image(img_url, use_column_width=True)
237
 
238
- def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
239
  loser = 'left' if winner == 'right' else 'right'
240
  battletime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
241
 
@@ -244,19 +254,24 @@ class RankingApp:
244
  winner_modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_ids[st.session_state.pointer[prompt_id][winner]]]['modelVersion_id'].values[0]
245
  loser_modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_ids[st.session_state.pointer[prompt_id][loser]]]['modelVersion_id'].values[0]
246
 
247
- # remove the old battle result if exists
248
- query = "DELETE FROM battle_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND winner = %s AND loser = %s"
249
- curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, winner_modelVersion_id, loser_modelVersion_id))
250
- curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, loser_modelVersion_id, winner_modelVersion_id))
251
 
252
  # insert the battle result into the database
253
- query = "INSERT INTO battle_results (username, timestamp, tag, prompt_id, winner, loser, battletime) VALUES (%s, %s, %s, %s, %s, %s, %s)"
254
- curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, winner_modelVersion_id, loser_modelVersion_id, battletime))
255
 
256
  curser.close()
257
  RANKING_CONN.commit()
258
 
259
  if curr_position == total_num - 1:
 
 
 
 
 
260
  st.session_state.progress[prompt_id] = 'finished'
261
 
262
  # drop 'modelVersion_standings' from session state if exists
@@ -266,8 +281,8 @@ class RankingApp:
266
  else:
267
  st.session_state.pointer[prompt_id][loser] = curr_position + 1
268
 
269
- def battle_mode(self, items, prompt_id):
270
- self.battle_images(items, prompt_id)
271
 
272
  def app(self):
273
  st.write('### Generative Model Ranking')
@@ -288,15 +303,43 @@ class RankingApp:
288
  'tag'] in prompt_tags else 0
289
  print(tag_idx)
290
 
291
- tag = st.radio('Select a tag', prompt_tags, index=tag_idx, horizontal=True, label_visibility='collapsed')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
 
 
293
  prompts = np.sort(items['prompt'].unique())[::-1].tolist()
294
 
295
  prompt_idx = prompts.index(st.session_state.gallery_focus['prompt']) if st.session_state.gallery_focus[
296
  'prompt'] in prompts else 0
297
  print(prompt_idx)
298
 
299
- selected_prompt = st.selectbox('Select a prompt', prompts, index=prompt_idx, label_visibility='collapsed')
 
 
 
 
 
 
 
 
 
 
 
300
 
301
  # mode = st.radio('Select a mode', ['Drag and Sort', 'Battle'], index=1)
302
  mode = st.session_state.assigned_rank_mode
@@ -317,13 +360,15 @@ class RankingApp:
317
  st.session_state.progress[prompt_id] = 'ranking'
318
 
319
  if st.session_state.progress[prompt_id] == 'ranking':
 
320
  st.caption("We might pair some other images that you haven't selected based on our evaluation matrix.")
321
  if mode == 'Drag and Sort':
322
- self.dragsort_mode(items, prompt_id, batch_num)
323
  elif mode == 'Battle':
324
- self.battle_mode(items, prompt_id)
325
 
326
  elif st.session_state.progress[prompt_id] == 'finished':
 
327
  # st.toast('**Summary is available now!**')
328
  # st.write('---')
329
  with st.form(key='ranking_finished'):
@@ -348,25 +393,45 @@ class RankingApp:
348
  st.form_submit_button('👆 Rank other prompts', use_container_width=True)
349
 
350
  with options_panel[3]:
351
- restart_btn = st.form_submit_button('🎖️ Re-rank this prompt', use_container_width=True)
352
- if restart_btn:
353
- st.session_state.progress[prompt_id] = 'ranking'
354
- st.session_state.counter[prompt_id] = 0
355
- st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
356
- st.experimental_rerun()
357
-
358
- def connect_to_db():
359
- conn = pymysql.connect(
360
- host=os.environ.get('RANKING_DB_HOST'),
361
- port=3306,
362
- database='myRanking',
363
- user=os.environ.get('RANKING_DB_USER'),
364
- password=os.environ.get('RANKING_DB_PASSWORD'),
365
- charset='utf8mb4',
366
- cursorclass=pymysql.cursors.DictCursor
367
- )
368
-
369
- return conn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
 
372
  if __name__ == "__main__":
 
11
  from streamlit_extras.switch_page_button import switch_page
12
 
13
  from pages.Gallery import load_hf_dataset
14
+ from Home import connect_to_db
15
 
16
 
17
  class RankingApp:
 
123
  for k in st.session_state.ranking[prompt_id][batch_idx].keys():
124
  st.session_state.ranking[prompt_id][batch_idx][k] = sorted_list.index(k)
125
 
126
+ def dragsort_mode(self, tag, items, prompt_id, batch_num):
127
  st.session_state.counter[prompt_id] = 0 if prompt_id not in st.session_state.counter else \
128
  st.session_state.counter[prompt_id]
129
 
 
150
  with control:
151
  if st.session_state.counter[prompt_id] < batch_num - 1:
152
  st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch',
153
+ kwargs={'tag': tag, 'prompt_id': prompt_id}, use_container_width=True, type='primary')
154
  else:
155
  st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished',
156
+ kwargs={'tag': tag, 'prompt_id': prompt_id, 'progress': 'finished'}, use_container_width=True, type='primary')
157
 
158
  if st.session_state.counter[prompt_id] > 0:
159
  st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch',
160
  kwargs={'prompt_id': prompt_id}, use_container_width=True)
161
 
162
+ def next_batch(self, tag, prompt_id, progress=None):
163
 
164
  curser = RANKING_CONN.cursor()
165
 
 
169
  modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
170
  position_version_dict[position] = modelVersion_id
171
 
172
+ # # get all records of this user and prompt
173
+ # query = "SELECT * FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s"
174
+ # curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id))
175
+ # results = curser.fetchall()
176
+ # print(results)
177
+ #
178
+ # # remove the old ranking with the same modelVersion_id if exists
179
+ # for result in results:
180
+ # prev_ids = [result['position1'], result['position2'], result['position3'], result['position4']]
181
+ # curr_ids = [position_version_dict[0], position_version_dict[1], position_version_dict[2], position_version_dict[3]]
182
+ # if len(set(prev_ids).intersection(set(curr_ids))) == 4:
183
+ # query = "DELETE FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND position1 = %s AND position2 = %s AND position3 = %s AND position4 = %s"
184
+ # curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, result['position1'], result['position2'], result['position3'], result['position4']))
185
+
186
+ sorttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
187
+ # handle the case where user press the 'prev' button
188
+ query = "DELETE FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND epoch = %s"
189
+ curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, st.session_state.epoch['ranking'][prompt_id]))
190
+ query = "INSERT INTO sort_results (username, timestamp, tag, prompt_id, position1, position2, position3, position4, sorttime, epoch) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
191
+ curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, position_version_dict[0], position_version_dict[1], position_version_dict[2], position_version_dict[3], sorttime, st.session_state.epoch['ranking'][prompt_id]))
192
 
193
  curser.close()
194
  RANKING_CONN.commit()
195
 
196
  if progress == 'finished':
197
+ st.session_state.epoch['ranking'][prompt_id] += 1
198
+
199
+ st.session_state.epoch['summary'][tag] = st.session_state.epoch['summary'].get(tag, 0) + 1
200
+ st.session_state.epoch['summary']['overall'] += 1
201
+
202
  st.session_state.progress[prompt_id] = 'finished'
203
  # drop 'modelVersion_standings' from session state if exists
204
  st.session_state.pop('modelVersion_standings', None)
 
208
  def prev_batch(self, prompt_id):
209
  st.session_state.counter[prompt_id] -= 1
210
 
211
+ def battle_images(self, tag, items, prompt_id):
212
  if 'pointer' not in st.session_state:
213
  st.session_state.pointer = {}
214
 
 
231
  # total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
232
  # st.write(f'Total Score: {total_score}')
233
 
234
+ btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'tag': tag, 'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
235
  st.image(img_url, use_column_width=True)
236
 
237
  with right:
 
242
  # total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
243
  # st.write(f'Total Score: {total_score}')
244
 
245
+ btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'tag': tag, 'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
246
  st.image(img_url, use_column_width=True)
247
 
248
+ def next_battle(self, tag, prompt_id, image_ids, winner, curr_position, total_num):
249
  loser = 'left' if winner == 'right' else 'right'
250
  battletime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
251
 
 
254
  winner_modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_ids[st.session_state.pointer[prompt_id][winner]]]['modelVersion_id'].values[0]
255
  loser_modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_ids[st.session_state.pointer[prompt_id][loser]]]['modelVersion_id'].values[0]
256
 
257
+ # # remove the old battle result if exists
258
+ # query = "DELETE FROM battle_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND winner = %s AND loser = %s"
259
+ # curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, winner_modelVersion_id, loser_modelVersion_id))
260
+ # curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, loser_modelVersion_id, winner_modelVersion_id))
261
 
262
  # insert the battle result into the database
263
+ query = "INSERT INTO battle_results (username, timestamp, tag, prompt_id, winner, loser, battletime, epoch) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)"
264
+ curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, winner_modelVersion_id, loser_modelVersion_id, battletime, st.session_state.epoch['ranking'][prompt_id]))
265
 
266
  curser.close()
267
  RANKING_CONN.commit()
268
 
269
  if curr_position == total_num - 1:
270
+ st.session_state.epoch['ranking'][prompt_id] += 1
271
+
272
+ st.session_state.epoch['summary'][tag] = st.session_state.epoch['summary'].get(tag, 0) + 1
273
+ st.session_state.epoch['summary']['overall'] += 1
274
+
275
  st.session_state.progress[prompt_id] = 'finished'
276
 
277
  # drop 'modelVersion_standings' from session state if exists
 
281
  else:
282
  st.session_state.pointer[prompt_id][loser] = curr_position + 1
283
 
284
+ def battle_mode(self, tag, items, prompt_id):
285
+ self.battle_images(tag, items, prompt_id)
286
 
287
  def app(self):
288
  st.write('### Generative Model Ranking')
 
303
  'tag'] in prompt_tags else 0
304
  print(tag_idx)
305
 
306
+ # color the finished tags
307
+ finished_tags = []
308
+ for tag in prompt_tags:
309
+ append_tag = True
310
+ for prompt_id in self.promptBook[self.promptBook['tag'] == tag]['prompt_id'].unique():
311
+ if prompt_id not in st.session_state.progress or st.session_state.progress[prompt_id] != 'finished':
312
+ append_tag = False
313
+ break
314
+ if append_tag:
315
+ finished_tags.append(tag)
316
+ tag_tobe_colored = self.promptBook[self.promptBook['tag'].isin(finished_tags)]['tag'].unique().tolist()
317
+ colored_tags = self.text_coloring_add(tag_tobe_colored, prompt_tags, color_name='orange')
318
+
319
+ tag = st.radio('Select a tag', colored_tags, index=tag_idx, horizontal=True, label_visibility='collapsed')
320
+ tag = self.text_coloring_remove(tag)
321
+ print(tag)
322
  items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
323
+ # pick out prompts such that st.session_state.progress[prompt_id] == 'finished'
324
+
325
  prompts = np.sort(items['prompt'].unique())[::-1].tolist()
326
 
327
  prompt_idx = prompts.index(st.session_state.gallery_focus['prompt']) if st.session_state.gallery_focus[
328
  'prompt'] in prompts else 0
329
  print(prompt_idx)
330
 
331
+ # color the finished prompts
332
+ finished_prompts = []
333
+ for prompt_id in items['prompt_id'].unique():
334
+ if prompt_id in st.session_state.progress and st.session_state.progress[prompt_id] == 'finished':
335
+ finished_prompts.append(prompt_id)
336
+ prompt_tobe_colored = items[items['prompt_id'].isin(finished_prompts)]['prompt'].unique().tolist()
337
+ colored_prompts = self.text_coloring_add(prompt_tobe_colored, prompts, color_name='✅')
338
+
339
+ selected_prompt = st.selectbox('Select a prompt', colored_prompts, index=prompt_idx, label_visibility='collapsed')
340
+ selected_prompt = self.text_coloring_remove(selected_prompt)
341
+
342
+ st.session_state.gallery_focus = {'tag': tag, 'prompt': selected_prompt}
343
 
344
  # mode = st.radio('Select a mode', ['Drag and Sort', 'Battle'], index=1)
345
  mode = st.session_state.assigned_rank_mode
 
360
  st.session_state.progress[prompt_id] = 'ranking'
361
 
362
  if st.session_state.progress[prompt_id] == 'ranking':
363
+ st.session_state.epoch['ranking'][prompt_id] = st.session_state.epoch['ranking'].get(prompt_id, 1)
364
  st.caption("We might pair some other images that you haven't selected based on our evaluation matrix.")
365
  if mode == 'Drag and Sort':
366
+ self.dragsort_mode(tag, items, prompt_id, batch_num)
367
  elif mode == 'Battle':
368
+ self.battle_mode(tag, items, prompt_id)
369
 
370
  elif st.session_state.progress[prompt_id] == 'finished':
371
+ print(st.session_state.gallery_focus)
372
  # st.toast('**Summary is available now!**')
373
  # st.write('---')
374
  with st.form(key='ranking_finished'):
 
393
  st.form_submit_button('👆 Rank other prompts', use_container_width=True)
394
 
395
  with options_panel[3]:
396
+ restart_btn = st.form_submit_button('🎖️ Re-rank this prompt', use_container_width=True, on_click=self.rerank, kwargs={'prompt_id': prompt_id})
397
+
398
+ # with st.sidebar:
399
+ # st.write('epoch: ', st.session_state.epoch['ranking'][prompt_id])
400
+ def text_coloring_add(self, tobe_colored:list, total_items, color_name='orange'):
401
+ if color_name in ['orange', 'red', 'green', 'blue', 'violet', 'yellow']:
402
+ colored = [f':{color_name}[{item}]' if item in tobe_colored else item for item in total_items]
403
+ else:
404
+ colored = [f'[{color_name}] {item}' if item in tobe_colored else item for item in total_items]
405
+ return colored
406
+
407
+ def text_coloring_remove(self, tobe_removed):
408
+ if isinstance(tobe_removed, str):
409
+ if tobe_removed.startswith(':'):
410
+ tobe_removed = tobe_removed.split('[')[-1][:-1]
411
+
412
+ elif tobe_removed.startswith('['):
413
+ tobe_removed = tobe_removed.split(']')[-1][1:]
414
+ return tobe_removed
415
+
416
+ def rerank(self, prompt_id):
417
+ st.session_state.progress[prompt_id] = 'ranking'
418
+ if st.session_state.assigned_rank_mode == 'Drag and Sort':
419
+ st.session_state.counter[prompt_id] = 0
420
+ elif st.session_state.assigned_rank_mode == 'Battle':
421
+ st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
422
+
423
+ # def connect_to_db():
424
+ # conn = pymysql.connect(
425
+ # host=os.environ.get('RANKING_DB_HOST'),
426
+ # port=3306,
427
+ # database='myRanking',
428
+ # user=os.environ.get('RANKING_DB_USER'),
429
+ # password=os.environ.get('RANKING_DB_PASSWORD'),
430
+ # charset='utf8mb4',
431
+ # cursorclass=pymysql.cursors.DictCursor
432
+ # )
433
+ #
434
+ # return conn
435
 
436
 
437
  if __name__ == "__main__":
pages/Summary.py CHANGED
@@ -15,7 +15,7 @@ from streamlit_extras.stylable_container import stylable_container
15
  from st_clickable_images import clickable_images
16
 
17
  from pages.Gallery import load_hf_dataset
18
- from pages.Ranking import connect_to_db
19
 
20
 
21
  class DashboardApp:
@@ -42,35 +42,52 @@ class DashboardApp:
42
  if back_to_ranking:
43
  switch_page('ranking')
44
 
45
- # with st.form('overall_feedback'):
46
- # feedback = st.text_area('Please leave your comments here.', key='comment')
47
- # submit_feedback = st.form_submit_button('Submit Feedback')
48
- # if submit_feedback:
49
- # print(feedback)
 
 
 
 
 
 
50
 
51
- # return tag
52
 
53
  def leaderboard(self, tag, db_table):
54
  tag = '%' if tag == 'overview' else tag
55
 
56
- # get the ranking results of the current user
 
 
57
  curser = RANKING_CONN.cursor()
58
- curser.execute(f"SELECT * FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag LIKE '{tag}'")
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  results = curser.fetchall()
60
  curser.close()
61
 
 
 
62
  if tag not in st.session_state.modelVersion_standings:
63
  st.session_state.modelVersion_standings[tag] = self.score_calculator(results, db_table)
64
 
65
  # sort the modelVersion_standings by value into a list of tuples in descending order
66
  st.session_state.modelVersion_standings[tag] = sorted(st.session_state.modelVersion_standings[tag].items(), key=lambda x: x[1], reverse=True)
67
-
68
- # tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
69
-
70
- # with tab1:
71
- # self.podium(modelVersion_standings)
72
- # switch_stage = st.toggle('Manual Reorder', key='switch_stage')
73
-
74
  example_prompts = []
75
  # get example images
76
  for key, value in st.session_state.selected_dict.items():
@@ -78,19 +95,8 @@ class DashboardApp:
78
  if model[0] in value:
79
  example_prompts.append(key)
80
 
81
- # if switch_stage:
82
- # self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit', example_prompts=example_prompts)
83
- # else:
84
  self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='display', example_prompts=example_prompts)
85
- # if st.session_state.summary_mode == 'display':
86
- # switch_stage = st.button('Manual Reorder', key='switch_stage_edit', on_click=lambda: st.session_state.__setitem__('summary_mode', 'edit'))
87
- # self.podium_expander(tag, n=3, summary_mode='display')
88
- #
89
- # elif st.session_state.summary_mode == 'edit':
90
- # switch_stage = st.button('Done', key='switch_stage_done', type='primary', on_click=lambda: st.session_state.__setitem__('summary_mode', 'display'))
91
- # self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit')
92
 
93
- # with tab2:
94
  st.write('---')
95
  st.write('**Detailed information of all selected models**')
96
  detailed_info = pd.merge(pd.DataFrame(st.session_state.modelVersion_standings[tag], columns=['modelVersion_id', 'ranking_score']), self.roster, on='modelVersion_id')
@@ -101,6 +107,7 @@ class DashboardApp:
101
  st.caption('You can click the header to sort the table by that column.')
102
 
103
  def podium_expander(self, tag, example_prompts, n=3, summary_mode: ['display', 'edit'] = 'display'):
 
104
 
105
  for i in range(n):
106
  modelVersion_id = st.session_state.modelVersion_standings[tag][i][0]
@@ -137,7 +144,7 @@ class DashboardApp:
137
  example_images = [f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{image}.png" for image in example_images]
138
  clickable_images(
139
  example_images,
140
- img_style={"margin": "5px", "height": "120px"}
141
  )
142
 
143
  else:
@@ -173,19 +180,33 @@ class DashboardApp:
173
  if i != n - 1:
174
  st.write('---')
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  def switch_order(self, tag, current, target):
177
- # st.session_state.modelVersion_standings[tag][current], st.session_state.modelVersion_standings[tag][target] = st.session_state.modelVersion_standings[tag][target], st.session_state.modelVersion_standings[tag][current]
178
  # insert the current before the target
179
  st.session_state.modelVersion_standings[tag].insert(target, st.session_state.modelVersion_standings[tag].pop(current))
180
- # print(st.session_state.modelVersion_standings[tag])
181
- # curser = RANKING_CONN.cursor()
182
- # RANKING_CONN.commit()
183
- # curser.close()
184
  curser = RANKING_CONN.cursor()
185
  # clear the current user's ranking results
186
- curser.execute(f"DELETE FROM summary_results WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag = '{tag}'")
187
  for i in range(len(st.session_state.modelVersion_standings[tag])):
188
- curser.execute(f"INSERT INTO summary_results (username, timestamp, tag, modelVersion_id, position, ranking_score) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{tag}', '{st.session_state.modelVersion_standings[tag][i][0]}', {i+1}, {st.session_state.modelVersion_standings[tag][i][1]})")
189
  RANKING_CONN.commit()
190
  curser.close()
191
 
@@ -197,7 +218,6 @@ class DashboardApp:
197
 
198
  for record in results:
199
  modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
200
-
201
  # add the loser who never wins
202
  if record['loser'] not in modelVersion_standings:
203
  modelVersion_standings[record['loser']] = 0
@@ -221,7 +241,7 @@ class DashboardApp:
221
  # get tags from database of the current user
222
  db_table = 'sort_results' if mode == 'Drag and Sort' else 'battle_results'
223
 
224
- tags = ['overview']
225
  curser = RANKING_CONN.cursor()
226
  curser.execute(
227
  f"SELECT DISTINCT tag FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}'")
@@ -229,29 +249,30 @@ class DashboardApp:
229
  tags.append(row['tag'])
230
  curser.close()
231
 
232
- if tags == ['overview']:
233
  st.info(f'No rankings are finished with {mode} mode yet.')
234
 
235
  else:
236
- tags = tags[1:2] if len(tags) == 2 else tags
 
237
  tag = st.radio('Select a tag', tags, index=0, horizontal=True, label_visibility='collapsed')
238
  self.sidebar(tags, mode)
239
  self.leaderboard(tag, db_table)
240
 
241
- with st.sidebar:
242
- with st.form('overall_feedback'):
243
- comment = st.text_area('Please leave your comments here.', key='comment')
244
- submit_feedback = st.form_submit_button('Submit Feedback')
245
- if submit_feedback:
246
- commenttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
247
- curser = RANKING_CONN.cursor()
248
- # parse the comment to at most 300 to avoid SQL injection
249
- for i in range(0, len(comment), 300):
250
- curser.execute(f"INSERT INTO comments (username, timestamp, comment, commenttime) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{comment[i:i+300]}', '{commenttime}')")
251
- RANKING_CONN.commit()
252
- curser.close()
253
-
254
- st.toast('🙏 **Thanks for your feedback! We will take it into consideration in our future work.**')
255
 
256
 
257
  if __name__ == "__main__":
 
15
  from st_clickable_images import clickable_images
16
 
17
  from pages.Gallery import load_hf_dataset
18
+ from Home import connect_to_db
19
 
20
 
21
  class DashboardApp:
 
42
  if back_to_ranking:
43
  switch_page('ranking')
44
 
45
+ with st.form('overall_feedback'):
46
+ comment = st.text_area('Please leave your comments here.', key='comment')
47
+ submit_feedback = st.form_submit_button('Submit Feedback')
48
+ if submit_feedback:
49
+ commenttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
50
+ curser = RANKING_CONN.cursor()
51
+ # parse the comment to at most 300 to avoid SQL injection
52
+ for i in range(0, len(comment), 300):
53
+ curser.execute(f"INSERT INTO comments (username, timestamp, comment, commenttime) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{comment[i:i+300]}', '{commenttime}')")
54
+ RANKING_CONN.commit()
55
+ curser.close()
56
 
57
+ st.sidebar.info('🙏 **Thanks for your feedback! We will take it into consideration in our future work.**')
58
 
59
  def leaderboard(self, tag, db_table):
60
  tag = '%' if tag == 'overview' else tag
61
 
62
+ # print('tag', tag)
63
+
64
+ # get the ranking results of the current user with the lastest epoch
65
  curser = RANKING_CONN.cursor()
66
+ # curser.execute(f"SELECT * FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag LIKE '{tag}'")
67
+ # curser.execute(f"SELECT * FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag LIKE '{tag}' ORDER BY epoch DESC LIMIT 1")
68
+ curser.execute(
69
+ f"SELECT * FROM {db_table}\
70
+ WHERE username = '{st.session_state.user_id[0]}'\
71
+ AND timestamp = '{st.session_state.user_id[1]}'\
72
+ AND tag LIKE '{tag}'\
73
+ AND epoch =\
74
+ (SELECT MAX(epoch) FROM {db_table}\
75
+ WHERE username = '{st.session_state.user_id[0]}'\
76
+ AND timestamp = '{st.session_state.user_id[1]}'\
77
+ AND tag LIKE '{tag}')")
78
+
79
+
80
  results = curser.fetchall()
81
  curser.close()
82
 
83
+ # print('results', results, len(results))
84
+
85
  if tag not in st.session_state.modelVersion_standings:
86
  st.session_state.modelVersion_standings[tag] = self.score_calculator(results, db_table)
87
 
88
  # sort the modelVersion_standings by value into a list of tuples in descending order
89
  st.session_state.modelVersion_standings[tag] = sorted(st.session_state.modelVersion_standings[tag].items(), key=lambda x: x[1], reverse=True)
90
+ print(st.session_state.modelVersion_standings[tag])
 
 
 
 
 
 
91
  example_prompts = []
92
  # get example images
93
  for key, value in st.session_state.selected_dict.items():
 
95
  if model[0] in value:
96
  example_prompts.append(key)
97
 
 
 
 
98
  self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='display', example_prompts=example_prompts)
 
 
 
 
 
 
 
99
 
 
100
  st.write('---')
101
  st.write('**Detailed information of all selected models**')
102
  detailed_info = pd.merge(pd.DataFrame(st.session_state.modelVersion_standings[tag], columns=['modelVersion_id', 'ranking_score']), self.roster, on='modelVersion_id')
 
107
  st.caption('You can click the header to sort the table by that column.')
108
 
109
  def podium_expander(self, tag, example_prompts, n=3, summary_mode: ['display', 'edit'] = 'display'):
110
+ self.save_summary(tag)
111
 
112
  for i in range(n):
113
  modelVersion_id = st.session_state.modelVersion_standings[tag][i][0]
 
144
  example_images = [f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{image}.png" for image in example_images]
145
  clickable_images(
146
  example_images,
147
+ img_style={"margin": "5px", "height": "120px"},
148
  )
149
 
150
  else:
 
180
  if i != n - 1:
181
  st.write('---')
182
 
183
+ def save_summary(self, tag):
184
+ # get the lastest summary_results epoch of the current user
185
+ tag_name = 'overview' if tag == '%' else tag
186
+ curser = RANKING_CONN.cursor()
187
+ curser.execute(f"SELECT epoch FROM summary_results WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag = '{tag_name}' ORDER BY epoch DESC LIMIT 1")
188
+ latest_epoch = curser.fetchone()
189
+ curser.close()
190
+ # print('latest_epoch',latest_epoch)
191
+ if latest_epoch is None or latest_epoch['epoch'] < st.session_state.epoch['summary'][tag_name]:
192
+ # save the current ranking results to the database
193
+ summarytime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
194
+ curser = RANKING_CONN.cursor()
195
+ for i in range(len(st.session_state.modelVersion_standings[tag])):
196
+ curser.execute(f"INSERT INTO summary_results (username, timestamp, tag, modelVersion_id, position, ranking_score, summarytime, epoch, customized) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{tag_name}', '{st.session_state.modelVersion_standings[tag][i][0]}', {i+1}, {st.session_state.modelVersion_standings[tag][i][1]}, '{summarytime}', {st.session_state.epoch['summary'][tag_name]}, 0)")
197
+ RANKING_CONN.commit()
198
+ curser.close()
199
+
200
  def switch_order(self, tag, current, target):
 
201
  # insert the current before the target
202
  st.session_state.modelVersion_standings[tag].insert(target, st.session_state.modelVersion_standings[tag].pop(current))
203
+ tag_name = 'overview' if tag == '%' else tag
204
+ summarytime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
 
205
  curser = RANKING_CONN.cursor()
206
  # clear the current user's ranking results
207
+ curser.execute(f"DELETE FROM summary_results WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND tag = '{tag_name}' AND epoch = {st.session_state.epoch['summary'][tag_name]}")
208
  for i in range(len(st.session_state.modelVersion_standings[tag])):
209
+ curser.execute(f"INSERT INTO summary_results (username, timestamp, tag, modelVersion_id, position, ranking_score, summarytime, epoch, customized) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{tag_name}', '{st.session_state.modelVersion_standings[tag][i][0]}', {i+1}, {st.session_state.modelVersion_standings[tag][i][1]}, '{summarytime}', {st.session_state.epoch['summary'][tag_name]}, 1)")
210
  RANKING_CONN.commit()
211
  curser.close()
212
 
 
218
 
219
  for record in results:
220
  modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
 
221
  # add the loser who never wins
222
  if record['loser'] not in modelVersion_standings:
223
  modelVersion_standings[record['loser']] = 0
 
241
  # get tags from database of the current user
242
  db_table = 'sort_results' if mode == 'Drag and Sort' else 'battle_results'
243
 
244
+ tags = []
245
  curser = RANKING_CONN.cursor()
246
  curser.execute(
247
  f"SELECT DISTINCT tag FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}'")
 
249
  tags.append(row['tag'])
250
  curser.close()
251
 
252
+ if len(tags) == 0:
253
  st.info(f'No rankings are finished with {mode} mode yet.')
254
 
255
  else:
256
+ # tags = tags[1:2] if len(tags) == 2 else tags
257
+ tag = ['overview'] + tags if len(tags) > 1 else tags
258
  tag = st.radio('Select a tag', tags, index=0, horizontal=True, label_visibility='collapsed')
259
  self.sidebar(tags, mode)
260
  self.leaderboard(tag, db_table)
261
 
262
+ # with st.sidebar:
263
+ # with st.form('overall_feedback'):
264
+ # comment = st.text_area('Please leave your comments here.', key='comment')
265
+ # submit_feedback = st.form_submit_button('Submit Feedback')
266
+ # if submit_feedback:
267
+ # commenttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
268
+ # curser = RANKING_CONN.cursor()
269
+ # # parse the comment to at most 300 to avoid SQL injection
270
+ # for i in range(0, len(comment), 300):
271
+ # curser.execute(f"INSERT INTO comments (username, timestamp, comment, commenttime) VALUES ('{st.session_state.user_id[0]}', '{st.session_state.user_id[1]}', '{comment[i:i+300]}', '{commenttime}')")
272
+ # RANKING_CONN.commit()
273
+ # curser.close()
274
+ #
275
+ # st.sidebar.info('🙏 **Thanks for your feedback! We will take it into consideration in our future work.**')
276
 
277
 
278
  if __name__ == "__main__":