Spaces:
Running
Running
add two-substage gallery page
Browse files- Home.py +2 -0
- pages/Gallery.py +104 -45
- pages/Summary.py +3 -3
Home.py
CHANGED
@@ -36,6 +36,8 @@ def logout():
|
|
36 |
st.session_state.pop('user_id', None)
|
37 |
st.session_state.pop('selected_dict', None)
|
38 |
st.session_state.pop('score_weights', None)
|
|
|
|
|
39 |
|
40 |
|
41 |
def info():
|
|
|
36 |
st.session_state.pop('user_id', None)
|
37 |
st.session_state.pop('selected_dict', None)
|
38 |
st.session_state.pop('score_weights', None)
|
39 |
+
st.session_state.pop('gallery_state', None)
|
40 |
+
st.session_state.pop('progress', None)
|
41 |
|
42 |
|
43 |
def info():
|
pages/Gallery.py
CHANGED
@@ -24,6 +24,14 @@ class GalleryApp:
|
|
24 |
self.promptBook = promptBook
|
25 |
self.images_ds = images_ds
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
def gallery_standard(self, items, col_num, info):
|
28 |
rows = len(items) // col_num + 1
|
29 |
containers = [st.container() for _ in range(rows)]
|
@@ -276,6 +284,7 @@ class GalleryApp:
|
|
276 |
# chosen_data = [stx.TabBarItemData(id=tag, title=tag, description='') for tag in prompt_tags]
|
277 |
# tag = stx.tab_bar(chosen_data, key='tag', default='food')
|
278 |
|
|
|
279 |
tag = st.radio('Select a tag', prompt_tags, index=5, horizontal=True, key='tag')
|
280 |
|
281 |
# tabs = st.tabs(prompt_tags)
|
@@ -284,23 +293,34 @@ class GalleryApp:
|
|
284 |
# tag = prompt_tags[i]
|
285 |
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
286 |
|
287 |
-
prompts = np.sort(items['prompt'].unique())[::1]
|
288 |
|
|
|
289 |
subset_selector = st.columns([3, 1])
|
290 |
with subset_selector[0]:
|
|
|
|
|
|
|
|
|
291 |
# selected_prompt = st.selectbox('Select prompt', prompts, index=3)
|
292 |
-
selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---')
|
293 |
-
|
294 |
-
subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
|
295 |
|
296 |
if selected_prompt is None:
|
297 |
-
st.markdown(':orange[Please select a prompt above👆]')
|
298 |
st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
|
|
|
|
|
|
|
|
|
299 |
else:
|
300 |
items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
301 |
prompt_id = items['prompt_id'].unique()[0]
|
302 |
note = items['note'].unique()[0]
|
303 |
-
|
|
|
|
|
|
|
304 |
|
305 |
# add safety check for some prompts
|
306 |
safety_check = True
|
@@ -316,16 +336,61 @@ class GalleryApp:
|
|
316 |
st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
|
317 |
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
|
318 |
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
#
|
326 |
-
#
|
327 |
-
|
|
|
|
|
328 |
self.graph_mode(prompt_id, items)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
|
330 |
try:
|
331 |
self.sidebar(items, prompt_id, note)
|
@@ -390,8 +455,6 @@ class GalleryApp:
|
|
390 |
infos_df = infos_df.rename(index={'model_name': 'Model', 'modelVersion_name': 'Version', 'model_download_count': 'Downloads', 'clip_score': 'Clip Score', 'mcos_score': 'mcos Score', 'nsfw_score': 'NSFW Score'})
|
391 |
st.table(infos_df)
|
392 |
|
393 |
-
st.button('🎖️ Proceed selections to ranking', on_click=switch_page, args=("ranking",), use_container_width=True,)
|
394 |
-
|
395 |
# for info in infos:
|
396 |
# st.write(f"**{info}**:")
|
397 |
# st.write(item[info])
|
@@ -399,8 +462,6 @@ class GalleryApp:
|
|
399 |
else:
|
400 |
st.info('Please click on an image to show')
|
401 |
|
402 |
-
|
403 |
-
|
404 |
def gallery_mode(self, prompt_id, items):
|
405 |
items, info, col_num = self.selection_panel(items)
|
406 |
|
@@ -423,33 +484,31 @@ class GalleryApp:
|
|
423 |
# if prompt:
|
424 |
# switch_page("ranking")
|
425 |
|
426 |
-
with st.form(key=f'{prompt_id}'):
|
427 |
# buttons = st.columns([1, 1, 1])
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
|
449 |
# st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
|
450 |
|
451 |
-
|
452 |
-
|
453 |
def submit_actions(self, status, prompt_id):
|
454 |
# remove counter from session state
|
455 |
# st.session_state.pop('counter', None)
|
@@ -473,7 +532,7 @@ class GalleryApp:
|
|
473 |
st.session_state.selected_dict[prompt_id].append(int(keys[2]))
|
474 |
# switch_page("ranking")
|
475 |
print(st.session_state.selected_dict, 'continue')
|
476 |
-
st.experimental_rerun()
|
477 |
|
478 |
def dynamic_weight(self, prompt_id, items, method='Grid Search'):
|
479 |
selected = items[
|
@@ -656,9 +715,9 @@ if __name__ == "__main__":
|
|
656 |
roster, promptBook, images_ds = load_hf_dataset()
|
657 |
# print(promptBook.columns)
|
658 |
|
659 |
-
# initialize selected_dict
|
660 |
-
if 'selected_dict' not in st.session_state:
|
661 |
-
|
662 |
|
663 |
app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
|
664 |
app.app()
|
|
|
24 |
self.promptBook = promptBook
|
25 |
self.images_ds = images_ds
|
26 |
|
27 |
+
# init gallery state
|
28 |
+
if 'gallery_state' not in st.session_state:
|
29 |
+
st.session_state.gallery_state = {}
|
30 |
+
|
31 |
+
# initialize selected_dict
|
32 |
+
if 'selected_dict' not in st.session_state:
|
33 |
+
st.session_state['selected_dict'] = {}
|
34 |
+
|
35 |
def gallery_standard(self, items, col_num, info):
|
36 |
rows = len(items) // col_num + 1
|
37 |
containers = [st.container() for _ in range(rows)]
|
|
|
284 |
# chosen_data = [stx.TabBarItemData(id=tag, title=tag, description='') for tag in prompt_tags]
|
285 |
# tag = stx.tab_bar(chosen_data, key='tag', default='food')
|
286 |
|
287 |
+
# save tag to session state on change
|
288 |
tag = st.radio('Select a tag', prompt_tags, index=5, horizontal=True, key='tag')
|
289 |
|
290 |
# tabs = st.tabs(prompt_tags)
|
|
|
293 |
# tag = prompt_tags[i]
|
294 |
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
295 |
|
296 |
+
prompts = np.sort(items['prompt'].unique())[::1].tolist()
|
297 |
|
298 |
+
st.caption('Select a prompt')
|
299 |
subset_selector = st.columns([3, 1])
|
300 |
with subset_selector[0]:
|
301 |
+
# remember last prompt
|
302 |
+
# if 'prompt_idx_last_time' not in st.session_state:
|
303 |
+
# st.session_state.prompt_idx_last_time = 0
|
304 |
+
|
305 |
# selected_prompt = st.selectbox('Select prompt', prompts, index=3)
|
306 |
+
selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---', label_visibility='collapsed', index=0)
|
307 |
+
# st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
|
|
|
308 |
|
309 |
if selected_prompt is None:
|
310 |
+
# st.markdown(':orange[Please select a prompt above👆]')
|
311 |
st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
|
312 |
+
|
313 |
+
with subset_selector[1]:
|
314 |
+
st.write(':orange[👈 **Please select a prompt**]')
|
315 |
+
|
316 |
else:
|
317 |
items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
318 |
prompt_id = items['prompt_id'].unique()[0]
|
319 |
note = items['note'].unique()[0]
|
320 |
+
|
321 |
+
# add state to session state
|
322 |
+
if prompt_id not in st.session_state.gallery_state:
|
323 |
+
st.session_state.gallery_state[prompt_id] = 'graph'
|
324 |
|
325 |
# add safety check for some prompts
|
326 |
safety_check = True
|
|
|
336 |
st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
|
337 |
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
|
338 |
|
339 |
+
print('current state: ', st.session_state.gallery_state[prompt_id])
|
340 |
+
|
341 |
+
if st.session_state.gallery_state[prompt_id] == 'graph':
|
342 |
+
if safety_check:
|
343 |
+
# if subset == 'Selected Only' and 'selected_dict' in st.session_state:
|
344 |
+
# # try:
|
345 |
+
# items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
|
346 |
+
# self.gallery_mode(prompt_id, items)
|
347 |
+
# # except:
|
348 |
+
# # st.warning('No selected images found')
|
349 |
+
# else:
|
350 |
self.graph_mode(prompt_id, items)
|
351 |
+
with subset_selector[1]:
|
352 |
+
# if st.session_state.gallery_state[prompt_id] == 'graph':
|
353 |
+
# subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
|
354 |
+
has_selection = False
|
355 |
+
try:
|
356 |
+
if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
|
357 |
+
has_selection = True
|
358 |
+
except:
|
359 |
+
pass
|
360 |
+
|
361 |
+
if has_selection:
|
362 |
+
checkout = st.button('Check out selections', use_container_width=True, type='primary')
|
363 |
+
if checkout:
|
364 |
+
print('checkout')
|
365 |
+
|
366 |
+
st.session_state.gallery_state[prompt_id] = 'gallery'
|
367 |
+
print(st.session_state.gallery_state[prompt_id])
|
368 |
+
st.experimental_rerun()
|
369 |
+
else:
|
370 |
+
st.write('Select images you like below 👇')
|
371 |
+
|
372 |
+
elif st.session_state.gallery_state[prompt_id] == 'gallery':
|
373 |
+
items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
|
374 |
+
drop=True)
|
375 |
+
self.gallery_mode(prompt_id, items)
|
376 |
+
|
377 |
+
with subset_selector[1]:
|
378 |
+
state_operations = st.columns([1, 1])
|
379 |
+
with state_operations[0]:
|
380 |
+
back = st.button('Back', use_container_width=True)
|
381 |
+
if back:
|
382 |
+
st.session_state.gallery_state[prompt_id] = 'graph'
|
383 |
+
st.experimental_rerun()
|
384 |
+
|
385 |
+
with state_operations[1]:
|
386 |
+
forward = st.button('Check out', use_container_width=True, type='primary', on_click=self.submit_actions, args=('Continue', prompt_id))
|
387 |
+
if forward:
|
388 |
+
switch_page('ranking')
|
389 |
+
|
390 |
+
|
391 |
+
|
392 |
+
# else:
|
393 |
+
# st.button('Proceed', use_container_width=True)
|
394 |
|
395 |
try:
|
396 |
self.sidebar(items, prompt_id, note)
|
|
|
455 |
infos_df = infos_df.rename(index={'model_name': 'Model', 'modelVersion_name': 'Version', 'model_download_count': 'Downloads', 'clip_score': 'Clip Score', 'mcos_score': 'mcos Score', 'nsfw_score': 'NSFW Score'})
|
456 |
st.table(infos_df)
|
457 |
|
|
|
|
|
458 |
# for info in infos:
|
459 |
# st.write(f"**{info}**:")
|
460 |
# st.write(item[info])
|
|
|
462 |
else:
|
463 |
st.info('Please click on an image to show')
|
464 |
|
|
|
|
|
465 |
def gallery_mode(self, prompt_id, items):
|
466 |
items, info, col_num = self.selection_panel(items)
|
467 |
|
|
|
484 |
# if prompt:
|
485 |
# switch_page("ranking")
|
486 |
|
487 |
+
# with st.form(key=f'{prompt_id}'):
|
488 |
# buttons = st.columns([1, 1, 1])
|
489 |
+
# buttons_space = st.columns([1, 1, 1])
|
490 |
+
gallery_space = st.empty()
|
491 |
+
|
492 |
+
# with buttons_space[0]:
|
493 |
+
# continue_btn = st.button('Proceed selections to ranking', use_container_width=True, type='primary')
|
494 |
+
# if continue_btn:
|
495 |
+
# # self.submit_actions('Continue', prompt_id)
|
496 |
+
# switch_page("ranking")
|
497 |
+
#
|
498 |
+
# with buttons_space[1]:
|
499 |
+
# deselect_btn = st.button('Deselect All', use_container_width=True)
|
500 |
+
# if deselect_btn:
|
501 |
+
# self.submit_actions('Deselect', prompt_id)
|
502 |
+
#
|
503 |
+
# with buttons_space[2]:
|
504 |
+
# refresh_btn = st.button('Refresh', on_click=gallery_space.empty, use_container_width=True)
|
505 |
|
506 |
+
with gallery_space.container():
|
507 |
+
with st.spinner('Loading images...'):
|
508 |
+
self.gallery_standard(items, col_num, info)
|
509 |
|
510 |
# st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
|
511 |
|
|
|
|
|
512 |
def submit_actions(self, status, prompt_id):
|
513 |
# remove counter from session state
|
514 |
# st.session_state.pop('counter', None)
|
|
|
532 |
st.session_state.selected_dict[prompt_id].append(int(keys[2]))
|
533 |
# switch_page("ranking")
|
534 |
print(st.session_state.selected_dict, 'continue')
|
535 |
+
# st.experimental_rerun()
|
536 |
|
537 |
def dynamic_weight(self, prompt_id, items, method='Grid Search'):
|
538 |
selected = items[
|
|
|
715 |
roster, promptBook, images_ds = load_hf_dataset()
|
716 |
# print(promptBook.columns)
|
717 |
|
718 |
+
# # initialize selected_dict
|
719 |
+
# if 'selected_dict' not in st.session_state:
|
720 |
+
# st.session_state['selected_dict'] = {}
|
721 |
|
722 |
app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
|
723 |
app.app()
|
pages/Summary.py
CHANGED
@@ -128,11 +128,11 @@ class DashboardApp:
|
|
128 |
st.image(image, use_column_width=True)
|
129 |
|
130 |
def score_calculator(self, results, db_table):
|
131 |
-
# sort results by battle time
|
132 |
-
results = sorted(results, key=lambda x: x['battletime'])
|
133 |
-
|
134 |
modelVersion_standings = {}
|
135 |
if db_table == 'battle_results':
|
|
|
|
|
|
|
136 |
for record in results:
|
137 |
modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
|
138 |
|
|
|
128 |
st.image(image, use_column_width=True)
|
129 |
|
130 |
def score_calculator(self, results, db_table):
|
|
|
|
|
|
|
131 |
modelVersion_standings = {}
|
132 |
if db_table == 'battle_results':
|
133 |
+
# sort results by battle time
|
134 |
+
results = sorted(results, key=lambda x: x['battletime'])
|
135 |
+
|
136 |
for record in results:
|
137 |
modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
|
138 |
|