Ricercar commited on
Commit
cbeb2ab
1 Parent(s): babb609

add selection indicators

Browse files
Files changed (2) hide show
  1. css/style.css +1 -1
  2. pages/Gallery.py +48 -3
css/style.css CHANGED
@@ -24,4 +24,4 @@ div.row-widget.stRadio > div[role="radiogroup"] > label[data-baseweb="radio"]:ha
24
  /*hide the circle of the radio button*/
25
  div.row-widget.stRadio > div[role="radiogroup"] > label[data-baseweb="radio"] > div:first-child {
26
  display: none;
27
- }
 
24
  /*hide the circle of the radio button*/
25
  div.row-widget.stRadio > div[role="radiogroup"] > label[data-baseweb="radio"] > div:first-child {
26
  display: none;
27
+ }
pages/Gallery.py CHANGED
@@ -34,6 +34,11 @@ class GalleryApp:
34
  if 'selected_dict' not in st.session_state:
35
  st.session_state['selected_dict'] = {}
36
 
 
 
 
 
 
37
  if 'gallery_focus' not in st.session_state:
38
  st.session_state.gallery_focus = {'tag': None, 'prompt': None}
39
 
@@ -261,6 +266,23 @@ class GalleryApp:
261
 
262
  # return prompt_tags, tag, prompt_id, items
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  def app(self):
265
  # print(st.session_state.gallery_focus)
266
  st.write('### Model Visualization and Retrieval')
@@ -279,8 +301,17 @@ class GalleryApp:
279
  else:
280
  tag_focus_idx = prompt_tags.index(st.session_state.gallery_focus['tag'])
281
 
 
 
 
 
 
282
  # save tag to session state on change
283
- tag = st.radio('Select a tag', prompt_tags, index=tag_focus_idx, horizontal=True, key='tag', label_visibility='collapsed')
 
 
 
 
284
 
285
  # print('current state: ', st.session_state.gallery_state)
286
 
@@ -299,8 +330,17 @@ class GalleryApp:
299
  # st.caption('Select a prompt')
300
  subset_selector = st.columns([3, 1])
301
  with subset_selector[0]:
302
- selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---', label_visibility='collapsed', index=prompt_focus_idx)
303
- # st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
 
 
 
 
 
 
 
 
 
304
 
305
  if selected_prompt is None:
306
  # st.markdown(':orange[Please select a prompt above👆]')
@@ -417,6 +457,11 @@ class GalleryApp:
417
  if select:
418
  st.session_state.selected_dict[item['prompt_id']].append(item['modelVersion_id'])
419
  self.remove_ranking_states(item['prompt_id'])
 
 
 
 
 
420
  st.experimental_rerun()
421
 
422
  # st.write(item)
 
34
  if 'selected_dict' not in st.session_state:
35
  st.session_state['selected_dict'] = {}
36
 
37
+ # clear up empty entries in seleted_dict
38
+ for prompt_id in list(st.session_state.selected_dict.keys()):
39
+ if len(st.session_state.selected_dict[prompt_id]) == 0:
40
+ st.session_state.selected_dict.pop(prompt_id)
41
+
42
  if 'gallery_focus' not in st.session_state:
43
  st.session_state.gallery_focus = {'tag': None, 'prompt': None}
44
 
 
266
 
267
  # return prompt_tags, tag, prompt_id, items
268
 
269
+ def text_coloring_add(self, tobe_colored:list, total_items, color_name='orange'):
270
+ if color_name in ['orange', 'red', 'green', 'blue', 'violet', 'yellow']:
271
+ colored = [f':{color_name}[{item}]' if item in tobe_colored else item for item in total_items]
272
+ else:
273
+ colored = [f'[{color_name}] {item}' if item in tobe_colored else item for item in total_items]
274
+ return colored
275
+
276
+ def text_coloring_remove(self, tobe_removed):
277
+ if isinstance(tobe_removed, str):
278
+ if tobe_removed.startswith(':'):
279
+ tobe_removed = tobe_removed.split('[')[-1][:-1]
280
+
281
+ elif tobe_removed.startswith('['):
282
+ tobe_removed = tobe_removed.split(']')[-1][1:]
283
+ return tobe_removed
284
+
285
+
286
  def app(self):
287
  # print(st.session_state.gallery_focus)
288
  st.write('### Model Visualization and Retrieval')
 
301
  else:
302
  tag_focus_idx = prompt_tags.index(st.session_state.gallery_focus['tag'])
303
 
304
+ # add coloring to tag based on selection
305
+ tags_tobe_colored = self.promptBook[self.promptBook['prompt_id'].isin(st.session_state.selected_dict.keys())]['tag'].unique()
306
+ # colored_prompt_tags = [f':orange[{tag}]' if tag in tags_tobe_colored else tag for tag in prompt_tags]
307
+ colored_prompt_tags = self.text_coloring_add(tags_tobe_colored, prompt_tags, color_name='orange')
308
+
309
  # save tag to session state on change
310
+ tag = st.radio('Select a tag', colored_prompt_tags, index=tag_focus_idx, horizontal=True, key='tag', label_visibility='collapsed')
311
+
312
+ # remove coloring from tag
313
+ tag = self.text_coloring_remove(tag)
314
+ print('tag: ', tag)
315
 
316
  # print('current state: ', st.session_state.gallery_state)
317
 
 
330
  # st.caption('Select a prompt')
331
  subset_selector = st.columns([3, 1])
332
  with subset_selector[0]:
333
+
334
+ # add coloring to prompt based on selection
335
+ prompts_tobe_colored = self.promptBook[self.promptBook['prompt_id'].isin(st.session_state.selected_dict.keys())]['prompt'].unique()
336
+ colored_prompts = self.text_coloring_add(prompts_tobe_colored, prompts, color_name='✅')
337
+
338
+ selected_prompt = selectbox('Select prompt', colored_prompts, key=f'prompt_{tag}', no_selection_label='---', label_visibility='collapsed', index=prompt_focus_idx)
339
+
340
+ # remove coloring from prompt
341
+ selected_prompt = self.text_coloring_remove(selected_prompt)
342
+ print('selected_prompt: ', selected_prompt)
343
+ st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
344
 
345
  if selected_prompt is None:
346
  # st.markdown(':orange[Please select a prompt above👆]')
 
457
  if select:
458
  st.session_state.selected_dict[item['prompt_id']].append(item['modelVersion_id'])
459
  self.remove_ranking_states(item['prompt_id'])
460
+
461
+ # add focus to session state
462
+ st.session_state.gallery_focus['tag'] = item['tag']
463
+ st.session_state.gallery_focus['prompt'] = item['prompt']
464
+
465
  st.experimental_rerun()
466
 
467
  # st.write(item)