Spaces:
Sleeping
Sleeping
update theshold
Browse files- app.py +101 -84
- data/download_script.py +3 -3
- data/roster/data-00000-of-00001.arrow +0 -3
- data/roster/dataset_info.json +0 -57
- data/roster/state.json +0 -13
app.py
CHANGED
@@ -26,7 +26,7 @@ def altair_histogram(hist_data, sort_by, mini, maxi):
|
|
26 |
chart = (
|
27 |
alt.Chart(hist_data)
|
28 |
.mark_bar(opacity=0.7, cornerRadius=2)
|
29 |
-
.encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=
|
30 |
# .add_selection(brushed)
|
31 |
# .properties(width=800, height=300)
|
32 |
)
|
@@ -84,28 +84,31 @@ class GalleryApp:
|
|
84 |
|
85 |
def gallery_standard(self, items, col_num, info):
|
86 |
rows = len(items) // col_num + 1
|
87 |
-
containers = [st.container() for _ in range(rows*2)]
|
|
|
88 |
for idx in range(0, len(items), col_num):
|
89 |
# assign one container for each row
|
90 |
-
row_idx = (idx // col_num) * 2
|
|
|
91 |
with containers[row_idx]:
|
92 |
cols = st.columns(col_num)
|
93 |
for j in range(col_num):
|
94 |
if idx + j < len(items):
|
95 |
with cols[j]:
|
96 |
# show image
|
97 |
-
image = self.images_ds[items.iloc[idx+j]['row_idx'].item()]['image']
|
98 |
|
99 |
-
st.image(image,
|
100 |
-
use_column_width=True,
|
101 |
-
)
|
102 |
|
103 |
-
# show checkbox
|
104 |
-
self.promptBook.loc[items.iloc[idx+j]['row_idx'].item(), 'checked'] = st.checkbox(
|
|
|
|
|
105 |
|
|
|
106 |
# show selected info
|
107 |
for key in info:
|
108 |
-
st.write(f"**{key}**: {items.iloc[idx+j][key]}")
|
109 |
|
110 |
# st.write(row_idx/2, idx+j, rows)
|
111 |
# extra_info = st.checkbox('Extra Info', key=f'extra_info_{idx+j}')
|
@@ -192,16 +195,19 @@ class GalleryApp:
|
|
192 |
return items, info, col_num
|
193 |
|
194 |
def selection_panel_2(self, items):
|
195 |
-
selecters = st.columns([1,
|
196 |
|
|
|
197 |
with selecters[0]:
|
198 |
-
sort_type = st.selectbox('Sort by', ['IDs and Names'
|
199 |
if sort_type == 'Scores':
|
200 |
sort_by = 'weighted_score_sum'
|
201 |
|
|
|
202 |
with selecters[1]:
|
203 |
if sort_type == 'IDs and Names':
|
204 |
-
sub_selecters = st.columns([3, 1
|
|
|
205 |
with sub_selecters[0]:
|
206 |
sort_by = st.selectbox('Sort by',
|
207 |
['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'],
|
@@ -210,81 +216,89 @@ class GalleryApp:
|
|
210 |
continue_idx = 1
|
211 |
|
212 |
else:
|
213 |
-
|
|
|
|
|
|
|
|
|
214 |
|
215 |
with sub_selecters[0]:
|
216 |
-
clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=
|
217 |
with sub_selecters[1]:
|
218 |
-
rank_weight = st.number_input('
|
219 |
with sub_selecters[2]:
|
220 |
-
pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=
|
|
|
|
|
221 |
|
222 |
items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
|
223 |
'norm_pop'] * pop_weight, 4)
|
224 |
|
225 |
continue_idx = 3
|
226 |
|
227 |
-
|
228 |
with sub_selecters[continue_idx]:
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
#
|
241 |
-
|
242 |
-
# return checked items
|
243 |
-
items = items[items['checked'] == False].reset_index(drop=True)
|
244 |
-
|
245 |
-
elif filter == 'Unsafe':
|
246 |
-
# return unchecked items
|
247 |
-
items = items[items['checked'] == True].reset_index(drop=True)
|
248 |
-
print(items)
|
249 |
|
250 |
# draw a distribution histogram
|
251 |
if sort_type == 'Scores':
|
252 |
-
|
253 |
-
st.
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
st.
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
|
|
|
269 |
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
|
|
275 |
|
276 |
-
#
|
277 |
-
|
278 |
-
for i in info:
|
279 |
-
if '+' in i:
|
280 |
-
mentioned = i.split('+')
|
281 |
-
for m in mentioned:
|
282 |
-
if SCORE_NAME_MAPPING[m] not in mentioned_scores:
|
283 |
-
mentioned_scores.append(SCORE_NAME_MAPPING[m])
|
284 |
-
if len(mentioned_scores) > 0:
|
285 |
-
st.info(
|
286 |
-
f"**Note:** The scores {mentioned_scores} are normalized to [0, 1] for each score type, and then added together. The higher the score, the better the model.")
|
287 |
|
|
|
288 |
col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
|
289 |
|
290 |
return items, info, col_num
|
@@ -351,6 +365,7 @@ class GalleryApp:
|
|
351 |
unsafe_prompts['people'] = [53]
|
352 |
unsafe_prompts['art'] = [23]
|
353 |
unsafe_prompts['abstract'] = [10, 12]
|
|
|
354 |
|
355 |
if int(prompt_id.item()) in unsafe_prompts[tag]:
|
356 |
st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
|
@@ -358,19 +373,18 @@ class GalleryApp:
|
|
358 |
|
359 |
if safety_check:
|
360 |
items, info, col_num = self.selection_panel_2(items)
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
self.gallery_standard(items, col_num, info)
|
374 |
|
375 |
def reset_current_prompt(self, prompt_id):
|
376 |
# reset current prompt
|
@@ -393,11 +407,15 @@ class GalleryApp:
|
|
393 |
dataset = dataset.add_column('checked', checked_info)
|
394 |
|
395 |
# print('metadata dataset: ', dataset)
|
|
|
396 |
dataset.push_to_hub('NYUSHPRP/ModelCofferMetadata', split='train')
|
397 |
|
398 |
|
399 |
@st.cache_data
|
400 |
def load_hf_dataset():
|
|
|
|
|
|
|
401 |
# load from huggingface
|
402 |
roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
|
403 |
promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
|
@@ -426,7 +444,6 @@ def load_hf_dataset():
|
|
426 |
|
427 |
|
428 |
if __name__ == '__main__':
|
429 |
-
login(token=os.environ.get("HF_TOKEN"))
|
430 |
st.set_page_config(layout="wide")
|
431 |
|
432 |
roster, promptBook, images_ds = load_hf_dataset()
|
|
|
26 |
chart = (
|
27 |
alt.Chart(hist_data)
|
28 |
.mark_bar(opacity=0.7, cornerRadius=2)
|
29 |
+
.encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=25)), y="count()")
|
30 |
# .add_selection(brushed)
|
31 |
# .properties(width=800, height=300)
|
32 |
)
|
|
|
84 |
|
85 |
def gallery_standard(self, items, col_num, info):
|
86 |
rows = len(items) // col_num + 1
|
87 |
+
# containers = [st.container() for _ in range(rows * 2)]
|
88 |
+
containers = [st.container() for _ in range(rows)]
|
89 |
for idx in range(0, len(items), col_num):
|
90 |
# assign one container for each row
|
91 |
+
# row_idx = (idx // col_num) * 2
|
92 |
+
row_idx = idx // col_num
|
93 |
with containers[row_idx]:
|
94 |
cols = st.columns(col_num)
|
95 |
for j in range(col_num):
|
96 |
if idx + j < len(items):
|
97 |
with cols[j]:
|
98 |
# show image
|
99 |
+
image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
|
100 |
|
101 |
+
st.image(image, use_column_width=True)
|
|
|
|
|
102 |
|
103 |
+
# # show checkbox
|
104 |
+
# self.promptBook.loc[items.iloc[idx + j]['row_idx'].item(), 'checked'] = st.checkbox(
|
105 |
+
# 'Select', value=self.promptBook.loc[items.iloc[idx + j]['row_idx'].item(), 'checked'],
|
106 |
+
# key=f'select_{idx + j}')
|
107 |
|
108 |
+
st.write(idx+j)
|
109 |
# show selected info
|
110 |
for key in info:
|
111 |
+
st.write(f"**{key}**: {items.iloc[idx + j][key]}")
|
112 |
|
113 |
# st.write(row_idx/2, idx+j, rows)
|
114 |
# extra_info = st.checkbox('Extra Info', key=f'extra_info_{idx+j}')
|
|
|
195 |
return items, info, col_num
|
196 |
|
197 |
def selection_panel_2(self, items):
|
198 |
+
selecters = st.columns([1, 4])
|
199 |
|
200 |
+
# select sort type
|
201 |
with selecters[0]:
|
202 |
+
sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
|
203 |
if sort_type == 'Scores':
|
204 |
sort_by = 'weighted_score_sum'
|
205 |
|
206 |
+
# select other options
|
207 |
with selecters[1]:
|
208 |
if sort_type == 'IDs and Names':
|
209 |
+
sub_selecters = st.columns([3, 1])
|
210 |
+
# select sort by
|
211 |
with sub_selecters[0]:
|
212 |
sort_by = st.selectbox('Sort by',
|
213 |
['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'],
|
|
|
216 |
continue_idx = 1
|
217 |
|
218 |
else:
|
219 |
+
# add custom weights
|
220 |
+
sub_selecters = st.columns([1, 1, 1, 1])
|
221 |
+
|
222 |
+
if 'default_weights' not in st.session_state:
|
223 |
+
st.session_state.default_weights = [1.0, 1.0, 1.0]
|
224 |
|
225 |
with sub_selecters[0]:
|
226 |
+
clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.default_weights[0], step=0.1, help='the weight for normalized clip score')
|
227 |
with sub_selecters[1]:
|
228 |
+
rank_weight = st.number_input('Distinctiveness Weight', min_value=-100.0, max_value=100.0, value=st.session_state.default_weights[1], step=0.1, help='the weight for average rank')
|
229 |
with sub_selecters[2]:
|
230 |
+
pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.default_weights[2], step=0.1, help='the weight for normalized popularity score')
|
231 |
+
|
232 |
+
st.session_state.default_weights = [clip_weight, rank_weight, pop_weight]
|
233 |
|
234 |
items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
|
235 |
'norm_pop'] * pop_weight, 4)
|
236 |
|
237 |
continue_idx = 3
|
238 |
|
239 |
+
# select threshold
|
240 |
with sub_selecters[continue_idx]:
|
241 |
+
dist_threshold = st.number_input('Distinctiveness Threshold', min_value=0.0, max_value=1.0, value=0.84, step=0.01, help='Only show models with distinctiveness score lower than this threshold, set 1.0 to show all images')
|
242 |
+
items = items[items['avg_rank'] < dist_threshold].reset_index(drop=True)
|
243 |
+
|
244 |
+
# filter = st.selectbox('Filter', ['Safe', 'All', 'Unsafe'])
|
245 |
+
# print('filter', filter)
|
246 |
+
# # initialize unsafe_modelVersion_ids
|
247 |
+
# if filter == 'Safe':
|
248 |
+
# # return unchecked items
|
249 |
+
# items = items[items['checked'] == False].reset_index(drop=True)
|
250 |
+
#
|
251 |
+
# elif filter == 'Unsafe':
|
252 |
+
# # return checked items
|
253 |
+
# items = items[items['checked'] == True].reset_index(drop=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
|
255 |
# draw a distribution histogram
|
256 |
if sort_type == 'Scores':
|
257 |
+
try:
|
258 |
+
with st.expander('Show score distribution histogram and select score range'):
|
259 |
+
st.write('**Score distribution histogram**')
|
260 |
+
chart_space = st.container()
|
261 |
+
# st.write('Select the range of scores to show')
|
262 |
+
hist_data = pd.DataFrame(items[sort_by])
|
263 |
+
mini = hist_data[sort_by].min().item()
|
264 |
+
mini = mini//0.1 * 0.1
|
265 |
+
maxi = hist_data[sort_by].max().item()
|
266 |
+
maxi = maxi//0.1 * 0.1 + 0.1
|
267 |
+
st.write('**Select the range of scores to show**')
|
268 |
+
r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
|
269 |
+
with chart_space:
|
270 |
+
st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
|
271 |
+
# event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
|
272 |
+
# r = event_dict.get(sort_by)
|
273 |
+
if r:
|
274 |
+
items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
|
275 |
+
# st.write(r)
|
276 |
+
except:
|
277 |
+
pass
|
278 |
+
|
279 |
+
display_options = st.columns([1, 4])
|
280 |
+
|
281 |
+
with display_options[0]:
|
282 |
+
# select order
|
283 |
+
order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
|
284 |
+
if order == 'Ascending':
|
285 |
+
order = True
|
286 |
+
else:
|
287 |
+
order = False
|
288 |
|
289 |
+
with display_options[1]:
|
290 |
|
291 |
+
# select info to show
|
292 |
+
info = st.multiselect('Show Info',
|
293 |
+
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
294 |
+
'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
|
295 |
+
'clip+rank+pop', 'weighted_score_sum'],
|
296 |
+
default=sort_by)
|
297 |
|
298 |
+
# apply sorting to dataframe
|
299 |
+
items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
|
301 |
+
# select number of columns
|
302 |
col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
|
303 |
|
304 |
return items, info, col_num
|
|
|
365 |
unsafe_prompts['people'] = [53]
|
366 |
unsafe_prompts['art'] = [23]
|
367 |
unsafe_prompts['abstract'] = [10, 12]
|
368 |
+
unsafe_prompts['food'] = [34]
|
369 |
|
370 |
if int(prompt_id.item()) in unsafe_prompts[tag]:
|
371 |
st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
|
|
|
373 |
|
374 |
if safety_check:
|
375 |
items, info, col_num = self.selection_panel_2(items)
|
376 |
+
self.gallery_standard(items, col_num, info)
|
377 |
+
|
378 |
+
# with st.form(key=f'{prompt_id}', clear_on_submit=True):
|
379 |
+
# buttons = st.columns([1, 1, 1])
|
380 |
+
# with buttons[0]:
|
381 |
+
# submit = st.form_submit_button('Save selections', on_click=self.save_checked, use_container_width=True, type='primary')
|
382 |
+
# with buttons[1]:
|
383 |
+
# submit = st.form_submit_button('Reset current prompt', on_click=self.reset_current_prompt, kwargs={'prompt_id': prompt_id} , use_container_width=True)
|
384 |
+
# with buttons[2]:
|
385 |
+
# submit = st.form_submit_button('Reset all selections', on_click=self.reset_all, use_container_width=True)
|
386 |
+
#
|
387 |
+
# self.gallery_standard(items, col_num, info)
|
|
|
388 |
|
389 |
def reset_current_prompt(self, prompt_id):
|
390 |
# reset current prompt
|
|
|
407 |
dataset = dataset.add_column('checked', checked_info)
|
408 |
|
409 |
# print('metadata dataset: ', dataset)
|
410 |
+
st.cache_data.clear()
|
411 |
dataset.push_to_hub('NYUSHPRP/ModelCofferMetadata', split='train')
|
412 |
|
413 |
|
414 |
@st.cache_data
|
415 |
def load_hf_dataset():
|
416 |
+
# login to huggingface
|
417 |
+
login(token=os.environ.get("HF_TOKEN"))
|
418 |
+
|
419 |
# load from huggingface
|
420 |
roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
|
421 |
promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
|
|
|
444 |
|
445 |
|
446 |
if __name__ == '__main__':
|
|
|
447 |
st.set_page_config(layout="wide")
|
448 |
|
449 |
roster, promptBook, images_ds = load_hf_dataset()
|
data/download_script.py
CHANGED
@@ -5,9 +5,9 @@ def main():
|
|
5 |
promptbook = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train')
|
6 |
print(promptbook)
|
7 |
promptbook.save_to_disk('./promptbook')
|
8 |
-
|
9 |
-
roster = load_dataset('NYUSHPRP/ModelCofferRoster', split='train')
|
10 |
-
roster.save_to_disk('./roster')
|
11 |
|
12 |
|
13 |
def load():
|
|
|
5 |
promptbook = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train')
|
6 |
print(promptbook)
|
7 |
promptbook.save_to_disk('./promptbook')
|
8 |
+
#
|
9 |
+
# roster = load_dataset('NYUSHPRP/ModelCofferRoster', split='train')
|
10 |
+
# roster.save_to_disk('./roster')
|
11 |
|
12 |
|
13 |
def load():
|
data/roster/data-00000-of-00001.arrow
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:0d92d1f86f02823ca64b7d88ffbb4c03a1ca8fe9990a54e37b9a1d9171782fca
|
3 |
-
size 147952
|
|
|
|
|
|
|
|
data/roster/dataset_info.json
DELETED
@@ -1,57 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"citation": "",
|
3 |
-
"dataset_size": 145934,
|
4 |
-
"description": "",
|
5 |
-
"download_checksums": {
|
6 |
-
"https://huggingface.co/datasets/NYUSHPRP/ModelCofferRoster/resolve/ca9efb0b73c3383dfb5bc9fff380b068d468bfde/data/train-00000-of-00001-0fd3ef44b360ac99.parquet": {
|
7 |
-
"num_bytes": 27979,
|
8 |
-
"checksum": null
|
9 |
-
}
|
10 |
-
},
|
11 |
-
"download_size": 27979,
|
12 |
-
"features": {
|
13 |
-
"tag": {
|
14 |
-
"dtype": "string",
|
15 |
-
"_type": "Value"
|
16 |
-
},
|
17 |
-
"model_name": {
|
18 |
-
"dtype": "string",
|
19 |
-
"_type": "Value"
|
20 |
-
},
|
21 |
-
"model_id": {
|
22 |
-
"dtype": "int64",
|
23 |
-
"_type": "Value"
|
24 |
-
},
|
25 |
-
"modelVersion_name": {
|
26 |
-
"dtype": "string",
|
27 |
-
"_type": "Value"
|
28 |
-
},
|
29 |
-
"modelVersion_id": {
|
30 |
-
"dtype": "int64",
|
31 |
-
"_type": "Value"
|
32 |
-
},
|
33 |
-
"modelVersion_url": {
|
34 |
-
"dtype": "string",
|
35 |
-
"_type": "Value"
|
36 |
-
},
|
37 |
-
"modelVersion_trainedWords": {
|
38 |
-
"dtype": "string",
|
39 |
-
"_type": "Value"
|
40 |
-
},
|
41 |
-
"model_download_count": {
|
42 |
-
"dtype": "int64",
|
43 |
-
"_type": "Value"
|
44 |
-
}
|
45 |
-
},
|
46 |
-
"homepage": "",
|
47 |
-
"license": "",
|
48 |
-
"size_in_bytes": 173913,
|
49 |
-
"splits": {
|
50 |
-
"train": {
|
51 |
-
"name": "train",
|
52 |
-
"num_bytes": 145934,
|
53 |
-
"num_examples": 1059,
|
54 |
-
"dataset_name": "parquet"
|
55 |
-
}
|
56 |
-
}
|
57 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/roster/state.json
DELETED
@@ -1,13 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"_data_files": [
|
3 |
-
{
|
4 |
-
"filename": "data-00000-of-00001.arrow"
|
5 |
-
}
|
6 |
-
],
|
7 |
-
"_fingerprint": "9508df8b007debc4",
|
8 |
-
"_format_columns": null,
|
9 |
-
"_format_kwargs": {},
|
10 |
-
"_format_type": null,
|
11 |
-
"_output_all_columns": false,
|
12 |
-
"_split": "train"
|
13 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|