Ricercar commited on
Commit
c389acc
1 Parent(s): 38c92eb

add manual prompt filter

Browse files
Files changed (2) hide show
  1. data/curation.json +14 -0
  2. pages/Gallery.py +8 -1
data/curation.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "abstract": [1,3],
3
+ "animal": [],
4
+ "architecture": [],
5
+ "art": [],
6
+ "artifact": [],
7
+ "food": [],
8
+ "illustration": [39],
9
+ "people": [49,50,51,62,54,56,48,60],
10
+ "produce & plant": [],
11
+ "scenery": [],
12
+ "vehicle": [],
13
+ "world knowledge": [84,83,85]
14
+ }
pages/Gallery.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import json
2
  import os
3
  import requests
@@ -187,7 +188,7 @@ class GalleryApp:
187
 
188
  # set focus tag and prompt index if exists
189
  if st.session_state.gallery_focus['tag'] is None:
190
- tag_focus_idx = 5
191
  else:
192
  tag_focus_idx = prompt_tags.index(st.session_state.gallery_focus['tag'])
193
 
@@ -591,6 +592,12 @@ def load_hf_dataset(show_NSFW=False):
591
  # add column to record current row index
592
  promptBook.loc[:, 'row_idx'] = promptBook.index
593
 
 
 
 
 
 
 
594
  # apply a nsfw filter
595
  if not show_NSFW:
596
  promptBook = promptBook[promptBook['norm_nsfw'] <= 0.8].reset_index(drop=True)
 
1
+ import itertools
2
  import json
3
  import os
4
  import requests
 
188
 
189
  # set focus tag and prompt index if exists
190
  if st.session_state.gallery_focus['tag'] is None:
191
+ tag_focus_idx = 0
192
  else:
193
  tag_focus_idx = prompt_tags.index(st.session_state.gallery_focus['tag'])
194
 
 
592
  # add column to record current row index
593
  promptBook.loc[:, 'row_idx'] = promptBook.index
594
 
595
+ # apply curation filter
596
+ prompt_to_hide = json.load(open('./data/curation.json', 'r'))
597
+ prompt_to_hide = list(itertools.chain.from_iterable(prompt_to_hide.values()))
598
+ print('prompt to hide: ', prompt_to_hide)
599
+ promptBook = promptBook[~promptBook['prompt_id'].isin(prompt_to_hide)].reset_index(drop=True)
600
+
601
  # apply a nsfw filter
602
  if not show_NSFW:
603
  promptBook = promptBook[promptBook['norm_nsfw'] <= 0.8].reset_index(drop=True)