Ricercar committed on
Commit
0ae586e
1 Parent(s): d65b5cb

beta version of new gallery view

Browse files
Archive/Gallery_beta0913.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+
4
+ import altair as alt
5
+ import numpy as np
6
+ import pandas as pd
7
+ import streamlit as st
8
+ import streamlit.components.v1 as components
9
+
10
+ from bs4 import BeautifulSoup
11
+ from datasets import load_dataset, Dataset, load_from_disk
12
+ from huggingface_hub import login
13
+ from streamlit_agraph import agraph, Node, Edge, Config
14
+ from streamlit_extras.switch_page_button import switch_page
15
+ from sklearn.svm import LinearSVC
16
+
17
+ SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
18
+
19
+
20
class GalleryApp:
    """Streamlit gallery application for browsing, selecting and ranking
    model-generated images, either as an image grid ('Gallery' mode) or as a
    t-SNE scatter graph ('Graph' mode)."""

    def __init__(self, promptBook, images_ds):
        # promptBook: pandas DataFrame of per-image metadata, already merged
        # with the model roster (one row per generated image).
        # images_ds: optional local image dataset; currently passed as None
        # because images are served from an S3 bucket instead.
        self.promptBook = promptBook
        self.images_ds = images_ds
24
+
25
+ def gallery_standard(self, items, col_num, info):
26
+ rows = len(items) // col_num + 1
27
+ containers = [st.container() for _ in range(rows)]
28
+ for idx in range(0, len(items), col_num):
29
+ row_idx = idx // col_num
30
+ with containers[row_idx]:
31
+ cols = st.columns(col_num)
32
+ for j in range(col_num):
33
+ if idx + j < len(items):
34
+ with cols[j]:
35
+ # show image
36
+ # image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
37
+ image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
38
+ st.image(image, use_column_width=True)
39
+
40
+ # handel checkbox information
41
+ prompt_id = items.iloc[idx + j]['prompt_id']
42
+ modelVersion_id = items.iloc[idx + j]['modelVersion_id']
43
+
44
+ check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
45
+
46
+ # st.write("Position: ", idx + j)
47
+
48
+ # show checkbox
49
+ st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
50
+
51
+ # show selected info
52
+ for key in info:
53
+ st.write(f"**{key}**: {items.iloc[idx + j][key]}")
54
+
55
    def gallery_graph(self, items):
        """Render `items` as an interactive streamlit-agraph scatter of image
        nodes positioned by precomputed t-SNE coordinates.

        Returns the id of the clicked node (an image_id) or None.
        """
        items = load_tsne_coordinates(items)

        # sort items to be popularity from low to high, so that most popular ones will be on the top
        items = items.sort_values(by=['model_download_count'], ascending=True).reset_index(drop=True)

        # spread the t-SNE coordinates out so nodes don't overlap on screen
        scale = 50
        items.loc[:, 'x'] = items['x'] * scale
        items.loc[:, 'y'] = items['y'] * scale

        nodes = []
        edges = []  # no edges are drawn; the graph is used as a scatter plot

        for idx in items.index:
            # one image node per row; the hover title carries the key scores
            nodes.append(Node(id=items.loc[idx, 'image_id'],
                              title=f"model name: {items.loc[idx, 'model_name']}\nmodelVersion name: {items.loc[idx, 'modelVersion_name']}\nclip score: {items.loc[idx, 'clip_score']}\nmcos score: {items.loc[idx, 'mcos_score']}\npopularity: {items.loc[idx, 'model_download_count']}",
                              size=20,
                              shape='image',
                              image=f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.loc[idx, 'image_id']}.png",
                              x=items.loc[idx, 'x'].item(),
                              y=items.loc[idx, 'y'].item(),
                              color={'background': '#E0E0E1', 'border': '#ffffff', 'highlight': {'border': '#F04542'}},
                              shadow={'enabled': True, 'color': 'rgba(0,0,0,0.4)', 'size': 10, 'x': 1, 'y': 1},
                              borderWidth=2,
                              shapeProperties={'useBorderWithImage': True},
                              )
                         )

        # physics disabled so nodes stay at their t-SNE positions
        config = Config(width='100%',
                        height='600',
                        directed=True,
                        physics=False,
                        hierarchical=False,
                        interaction={'navigationButtons': True, 'dragNodes': False, 'multiselect': False},
                        )

        return agraph(nodes=nodes,
                      edges=edges,
                      config=config,
                      )
104
+
105
+ def selection_panel(self, items):
106
+ # temperal function
107
+
108
+ selecters = st.columns([1, 4])
109
+
110
+ if 'score_weights' not in st.session_state:
111
+ st.session_state.score_weights = [1.0, 0.8, 0.2, 0.8]
112
+
113
+ # select sort type
114
+ with selecters[0]:
115
+ sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
116
+ if sort_type == 'Scores':
117
+ sort_by = 'weighted_score_sum'
118
+
119
+ # select other options
120
+ with selecters[1]:
121
+ if sort_type == 'IDs and Names':
122
+ sub_selecters = st.columns([3, 1])
123
+ # select sort by
124
+ with sub_selecters[0]:
125
+ sort_by = st.selectbox('Sort by',
126
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id', 'norm_nsfw'],
127
+ label_visibility='hidden')
128
+
129
+ continue_idx = 1
130
+
131
+ else:
132
+ # add custom weights
133
+ sub_selecters = st.columns([1, 1, 1, 1])
134
+
135
+ with sub_selecters[0]:
136
+ clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1, help='the weight for normalized clip score')
137
+ with sub_selecters[1]:
138
+ mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=0.8, step=0.1, help='the weight for m(eam) s(imilarity) q(antile) score for measuring distinctiveness')
139
+ with sub_selecters[2]:
140
+ pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=0.2, step=0.1, help='the weight for normalized popularity score')
141
+
142
+ items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
143
+ 'norm_pop'] * pop_weight, 4)
144
+
145
+ continue_idx = 3
146
+
147
+ # save latest weights
148
+ st.session_state.score_weights[0] = round(clip_weight, 2)
149
+ st.session_state.score_weights[1] = round(mcos_weight, 2)
150
+ st.session_state.score_weights[2] = round(pop_weight, 2)
151
+
152
+ # select threshold
153
+ with sub_selecters[continue_idx]:
154
+ nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=0.8, step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
155
+ items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
156
+
157
+ # save latest threshold
158
+ st.session_state.score_weights[3] = nsfw_threshold
159
+
160
+ # draw a distribution histogram
161
+ if sort_type == 'Scores':
162
+ try:
163
+ with st.expander('Show score distribution histogram and select score range'):
164
+ st.write('**Score distribution histogram**')
165
+ chart_space = st.container()
166
+ # st.write('Select the range of scores to show')
167
+ hist_data = pd.DataFrame(items[sort_by])
168
+ mini = hist_data[sort_by].min().item()
169
+ mini = mini//0.1 * 0.1
170
+ maxi = hist_data[sort_by].max().item()
171
+ maxi = maxi//0.1 * 0.1 + 0.1
172
+ st.write('**Select the range of scores to show**')
173
+ r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
174
+ with chart_space:
175
+ st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
176
+ # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
177
+ # r = event_dict.get(sort_by)
178
+ if r:
179
+ items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
180
+ # st.write(r)
181
+ except:
182
+ pass
183
+
184
+ display_options = st.columns([1, 4])
185
+
186
+ with display_options[0]:
187
+ # select order
188
+ order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
189
+ if order == 'Ascending':
190
+ order = True
191
+ else:
192
+ order = False
193
+
194
+ with display_options[1]:
195
+
196
+ # select info to show
197
+ info = st.multiselect('Show Info',
198
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
199
+ 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
200
+ 'nsfw_score', 'norm_nsfw'],
201
+ default=sort_by)
202
+
203
+ # apply sorting to dataframe
204
+ items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
205
+
206
+ # select number of columns
207
+ col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
208
+
209
+ return items, info, col_num
210
+
211
    def sidebar(self):
        """Render the sidebar: tag/prompt pickers, view-mode switch, prompt
        metadata, and (for civitai-sourced prompts) a link plus a scraped
        preview of the original civitai image.

        Returns:
            (prompt_tags, tag, prompt_id, items, mode) where `items` is the
            promptBook subset for the chosen prompt.
        """
        with st.sidebar:
            prompt_tags = self.promptBook['tag'].unique()
            # sort tags by alphabetical order ([::1] is a no-op slice / copy)
            prompt_tags = np.sort(prompt_tags)[::1]

            tag = st.selectbox('Select a tag', prompt_tags, index=5)

            items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)

            prompts = np.sort(items['prompt'].unique())[::1]

            selected_prompt = st.selectbox('Select prompt', prompts, index=3)

            mode = st.radio('Select a mode', ['Gallery', 'Graph'], horizontal=True, index=1)

            items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
            prompt_id = items['prompt_id'].unique()[0]
            note = items['note'].unique()[0]

            # show source: an all-digit note string is a civitai image id
            if isinstance(note, str):
                if note.isdigit():
                    st.caption(f"`Source: civitai`")
                else:
                    st.caption(f"`Source: {note}`")
            else:
                st.caption("`Source: Parti-prompts`")

            # show image metadata
            image_metadatas = ['prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
            for key in image_metadatas:
                label = ' '.join(key.split('_')).capitalize()
                st.write(f"**{label}**")
                # a single-space value marks an empty field in the dataset
                if items[key][0] == ' ':
                    st.write('`None`')
                else:
                    st.caption(f"{items[key][0]}")

            # for note as civitai image id, add civitai reference
            if isinstance(note, str) and note.isdigit():
                try:
                    st.write(f'**[Civitai Reference](https://civitai.com/images/{note})**')
                    res = requests.get(f'https://civitai.com/images/{note}')
                    soup = BeautifulSoup(res.text, 'html.parser')
                    # NOTE(review): 'mantine-12rlksp' is a generated CSS class on
                    # civitai's page and will likely break on a site redeploy
                    image_section = soup.find('div', {'class': 'mantine-12rlksp'})
                    image_url = image_section.find('img')['src']
                    st.image(image_url, use_column_width=True)
                except:
                    # best-effort: any scraping/network failure is silently ignored
                    pass

        return prompt_tags, tag, prompt_id, items, mode
264
+
265
    def app(self):
        """Page entry point: render the sidebar, then the chosen Gallery or
        Graph view, gated by a confirmation checkbox for known-unsafe
        prompts."""
        st.title('Model Visualization and Retrieval')
        st.write('This is a gallery of images generated by the models')

        prompt_tags, tag, prompt_id, items, mode = self.sidebar()

        # add safety check for some prompts
        safety_check = True
        unsafe_prompts = {}
        # initialize unsafe prompts: every tag gets an (empty) list so the
        # lookup below never raises KeyError
        for prompt_tag in prompt_tags:
            unsafe_prompts[prompt_tag] = []
        # manually curated list of unsafe prompt ids per tag
        unsafe_prompts['world knowledge'] = [83]
        unsafe_prompts['abstract'] = [1, 3]

        if int(prompt_id.item()) in unsafe_prompts[tag]:
            st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
            safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')

        if safety_check:
            if mode == 'Gallery':
                self.gallery_mode(prompt_id, items)
            elif mode == 'Graph':
                self.graph_mode(prompt_id, items)
298
+
299
+
300
    def graph_mode(self, prompt_id, items):
        """Render the t-SNE graph view plus a side panel showing the clicked
        image with select/deselect controls and its key scores."""
        graph_cols = st.columns([3, 1])
        # the chat input doubles as a shortcut to the ranking page
        prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
                               disabled=False, key=f'{prompt_id}')
        if prompt:
            switch_page("ranking")

        with graph_cols[0]:
            graph_space = st.empty()

            with graph_space.container():
                # agraph returns the id of the clicked node (an image_id), or None
                return_value = self.gallery_graph(items)

        with graph_cols[1]:
            if return_value:
                with st.form(key=f'{prompt_id}'):
                    image_url = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{return_value}.png"

                    st.image(image_url)

                    item = items[items['image_id'] == return_value].reset_index(drop=True).iloc[0]
                    modelVersion_id = item['modelVersion_id']

                    # handle selection: toggle this model version in/out of the
                    # per-prompt selection list
                    if 'selected_dict' in st.session_state:
                        if item['prompt_id'] not in st.session_state.selected_dict:
                            st.session_state.selected_dict[item['prompt_id']] = []

                        if modelVersion_id in st.session_state.selected_dict[item['prompt_id']]:
                            checked = True
                        else:
                            checked = False

                        if checked:
                            deselect = st.form_submit_button('Deselect', use_container_width=True)
                            if deselect:
                                st.session_state.selected_dict[item['prompt_id']].remove(item['modelVersion_id'])
                                # changing the selection invalidates ranking state
                                self.remove_ranking_states(item['prompt_id'])
                                st.experimental_rerun()

                        else:
                            select = st.form_submit_button('Select', use_container_width=True, type='primary')
                            if select:
                                st.session_state.selected_dict[item['prompt_id']].append(item['modelVersion_id'])
                                self.remove_ranking_states(item['prompt_id'])
                                st.experimental_rerun()

                    # score/metadata table for the clicked image
                    infos = ['model_name', 'modelVersion_name', 'model_download_count', 'clip_score', 'mcos_score',
                             'nsfw_score']

                    infos_df = item[infos]
                    # rename index entries to user-friendly labels
                    infos_df = infos_df.rename(index={'model_name': 'Model', 'modelVersion_name': 'Version', 'model_download_count': 'Downloads', 'clip_score': 'Clip Score', 'mcos_score': 'mcos Score', 'nsfw_score': 'NSFW Score'})
                    st.table(infos_df)

            else:
                st.info('Please click on an image to show')
364
+
365
+
366
    def gallery_mode(self, prompt_id, items):
        """Render the grid gallery view: selection panel, dynamic-weight
        recommendation buttons, and a form with select/deselect/confirm
        actions around the image grid."""
        items, info, col_num = self.selection_panel(items)

        if 'selected_dict' in st.session_state:
            dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
            dynamic_weight_panel = st.columns(len(dynamic_weight_options))

            # weight learning needs at least one selected image for this prompt
            if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
                btn_disable = False
            else:
                btn_disable = True

            for i in range(len(dynamic_weight_options)):
                method = dynamic_weight_options[i]
                with dynamic_weight_panel[i]:
                    btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))

        # the chat input doubles as a shortcut to the ranking page
        prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
        if prompt:
            switch_page("ranking")

        with st.form(key=f'{prompt_id}'):
            buttons_space = st.columns([1, 1, 1, 1])
            gallery_space = st.empty()

            with buttons_space[0]:
                continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
                if continue_btn:
                    # persist checkbox states into selected_dict
                    self.submit_actions('Continue', prompt_id)

            with buttons_space[1]:
                select_btn = st.form_submit_button('Select All', use_container_width=True)
                if select_btn:
                    self.submit_actions('Select', prompt_id)

            with buttons_space[2]:
                deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
                if deselect_btn:
                    self.submit_actions('Deselect', prompt_id)

            with buttons_space[3]:
                # clears and re-renders the gallery area
                refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)

            with gallery_space.container():
                with st.spinner('Loading images...'):
                    self.gallery_standard(items, col_num, info)

            st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
416
+
417
+
418
+
419
+ def submit_actions(self, status, prompt_id):
420
+ # remove counter from session state
421
+ # st.session_state.pop('counter', None)
422
+ self.remove_ranking_states('prompt_id')
423
+ if status == 'Select':
424
+ modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
425
+ st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
426
+ print(st.session_state.selected_dict, 'select')
427
+ st.experimental_rerun()
428
+ elif status == 'Deselect':
429
+ st.session_state.selected_dict[prompt_id] = []
430
+ print(st.session_state.selected_dict, 'deselect')
431
+ st.experimental_rerun()
432
+ # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
433
+ elif status == 'Continue':
434
+ st.session_state.selected_dict[prompt_id] = []
435
+ for key in st.session_state:
436
+ keys = key.split('_')
437
+ if keys[0] == 'select' and keys[1] == str(prompt_id):
438
+ if st.session_state[key]:
439
+ st.session_state.selected_dict[prompt_id].append(int(keys[2]))
440
+ # switch_page("ranking")
441
+ print(st.session_state.selected_dict, 'continue')
442
+ st.experimental_rerun()
443
+
444
    def dynamic_weight(self, prompt_id, items, method='Grid Search'):
        """Learn score weights (clip, mcos, popularity) that favour the
        user's currently selected images, and store them in
        st.session_state.score_weights[0:3].

        Args:
            prompt_id: prompt whose selection drives the optimization.
            items: full DataFrame of candidate images for the prompt.
            method: 'Grid Search', 'SVM' or 'Greedy'.
        """
        selected = items[
            items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
        optimal_weight = [0, 0, 0]

        if method == 'Grid Search':
            # grid search: minimize the sum of rank positions of the selected
            # images under the weighted score; worst case is every selected
            # item ranked last, bounded by len(items) * len(selected)
            top_ranking = len(items) * len(selected)

            for clip_weight in np.arange(-1, 1, 0.1):
                for mcos_weight in np.arange(-1, 1, 0.1):
                    for pop_weight in np.arange(-1, 1, 0.1):

                        weight_all = clip_weight*items['norm_clip'] + mcos_weight*items['norm_mcos'] + pop_weight*items['norm_pop']
                        weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
                        weight_selected = clip_weight*selected['norm_clip'] + mcos_weight*selected['norm_mcos'] + pop_weight*selected['norm_pop']

                        # get the index of values of weight_selected in weight_all_sorted
                        # NOTE(review): relies on exact float equality between the two
                        # identically-computed sums; fragile to any arithmetic reordering
                        rankings = []
                        for weight in weight_selected:
                            rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
                        if sum(rankings) <= top_ranking:
                            top_ranking = sum(rankings)
                            print('current top ranking:', top_ranking, rankings)
                            optimal_weight = [clip_weight, mcos_weight, pop_weight]
            print('optimal weight:', optimal_weight)

        elif method == 'SVM':
            # SVM: fit a separating hyperplane between selected and unselected
            # images; its normal vector becomes the weight vector
            print('start svm method')
            # get residual dataframe that contains models not selected
            residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
            residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
            residual = residual.to_numpy()
            selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
            selected = selected.to_numpy()

            # selected items labeled -1, the rest +1
            y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
            X = np.concatenate((selected, residual), axis=0)

            # fit svm model, and get parameters for the hyperplane
            clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
            clf.fit(X, y)
            optimal_weight = clf.coef_[0].tolist()
            print('optimal weight:', optimal_weight)
            pass

        elif method == 'Greedy':
            # greedy: for each selected image, vote for whichever normalized
            # score is largest, then normalize the vote counts
            for idx in selected.index:
                clip_score = selected.loc[idx, 'norm_clip_crop']
                mcos_score = selected.loc[idx, 'norm_mcos_crop']
                pop_score = selected.loc[idx, 'norm_pop']
                if clip_score >= mcos_score and clip_score >= pop_score:
                    optimal_weight[0] += 1
                elif mcos_score >= clip_score and mcos_score >= pop_score:
                    optimal_weight[1] += 1
                elif pop_score >= clip_score and pop_score >= mcos_score:
                    optimal_weight[2] += 1

            # normalize optimal_weight
            optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
            print('optimal weight:', optimal_weight)
        print('optimal weight:', optimal_weight)

        st.session_state.score_weights[0: 3] = optimal_weight
511
+
512
+
513
+ def remove_ranking_states(self, prompt_id):
514
+ # for drag sort
515
+ try:
516
+ st.session_state.counter[prompt_id] = 0
517
+ st.session_state.ranking[prompt_id] = {}
518
+ print('remove ranking states')
519
+ except:
520
+ print('no sort ranking states to remove')
521
+
522
+ # for battles
523
+ try:
524
+ st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
525
+ print('remove battles states')
526
+ except:
527
+ print('no battles states to remove')
528
+
529
+ # for page progress
530
+ try:
531
+ st.session_state.progress[prompt_id] = 'ranking'
532
+ print('reset page progress states')
533
+ except:
534
+ print('no page progress states to be reset')
535
+
536
+
537
@st.cache_resource
def altair_histogram(hist_data, sort_by, mini, maxi):
    """Build a layered Altair chart: a histogram of hist_data[sort_by] with a
    translucent rectangle highlighting the [mini, maxi] range.

    Args:
        hist_data: single-column DataFrame of scores.
        sort_by: the column to histogram.
        mini, maxi: bounds of the currently selected score range.
    """
    # interval selection kept from an earlier interactive version; currently unused
    brushed = alt.selection_interval(encodings=['x'], name="brushed")

    chart = (
        alt.Chart(hist_data)
        .mark_bar(opacity=0.7, cornerRadius=2)
        .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=25)), y="count()")
    )

    # Create a transparent rectangle for highlighting the range
    highlight = (
        alt.Chart(pd.DataFrame({'x1': [mini], 'x2': [maxi]}))
        .mark_rect(opacity=0.3)
        .encode(x='x1', x2='x2')
    )

    # Layer the chart and the highlight rectangle
    layered_chart = alt.layer(chart, highlight)

    return layered_chart
562
+
563
+
564
@st.cache_data
def load_hf_dataset():
    """Load the model roster and image metadata from the Hugging Face hub,
    merge them, apply an NSFW filter, and precompute a default total score.

    Returns:
        (roster, promptBook, images_ds) — images_ds is always None for now
        because images are served from S3 instead of a local dataset.
    """
    # login to huggingface; requires the HF_TOKEN environment variable
    login(token=os.environ.get("HF_TOKEN"))

    # load from huggingface
    roster = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Roster', split='train'))
    promptBook = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Metadata', split='train'))
    images_ds = None  # set to None for now since we use s3 bucket to store images

    # add 'custom_score_weights' column to promptBook if not exist
    if 'weighted_score_sum' not in promptBook.columns:
        promptBook.loc[:, 'weighted_score_sum'] = 0

    # merge roster and promptbook
    promptBook = promptBook.merge(roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']],
                                  on=['model_id', 'modelVersion_id'], how='left')

    # add column to record current row index
    promptBook.loc[:, 'row_idx'] = promptBook.index

    # apply a nsfw filter (0.84 is a hard-coded cutoff)
    promptBook = promptBook[promptBook['nsfw_score'] <= 0.84].reset_index(drop=True)

    # add a column that adds up 'norm_clip', 'norm_mcos', and 'norm_pop'
    # with the default weights
    score_weights = [1.0, 0.8, 0.2]
    promptBook.loc[:, 'total_score'] = round(promptBook['norm_clip'] * score_weights[0] + promptBook['norm_mcos'] * score_weights[1] + promptBook['norm_pop'] * score_weights[2], 4)

    return roster, promptBook, images_ds
598
+
599
@st.cache_data
def load_tsne_coordinates(items):
    """Join precomputed 2-D t-SNE coordinates ('x', 'y' columns) onto `items`
    by (modelVersion_id, prompt_id)."""
    # load tsne coordinates from the local parquet file
    tsne_df = pd.read_parquet('./data/feats_tsne.parquet')

    # debug output: prints the full DataFrames to stdout
    print('before merge:', items)
    items = items.merge(tsne_df, on=['modelVersion_id', 'prompt_id'], how='left')
    print('after merge:', items)
    return items
610
+
611
+
612
if __name__ == "__main__":
    st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")

    # gate the page behind the login performed on the home page
    if 'user_id' not in st.session_state:
        st.warning('Please log in first.')
        home_btn = st.button('Go to Home Page')
        if home_btn:
            switch_page("home")
    else:
        roster, promptBook, images_ds = load_hf_dataset()

        # initialize selected_dict: {prompt_id: [modelVersion_id, ...]}
        if 'selected_dict' not in st.session_state:
            st.session_state['selected_dict'] = {}

        app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
        app.app()
Archive/agraphTest.py DELETED
@@ -1,170 +0,0 @@
1
- import os
2
-
3
- import streamlit as st
4
- import torch
5
- import pandas as pd
6
- import numpy as np
7
-
8
- from datasets import load_dataset, Dataset, load_from_disk
9
- from huggingface_hub import login
10
- from streamlit_agraph import agraph, Node, Edge, Config
11
- from sklearn.manifold import TSNE
12
-
13
-
14
@st.cache_data
def load_hf_dataset():
    """Load and merge the model roster and image metadata from the Hugging
    Face hub (archived agraph-test variant — no NSFW filter or score column).

    Returns:
        (roster, promptBook) as pandas DataFrames.
    """
    # login to huggingface; requires the HF_TOKEN environment variable
    login(token=os.environ.get("HF_TOKEN"))

    # load from huggingface
    roster = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Roster', split='train'))
    promptBook = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Metadata', split='train'))

    # process dataset: keep one row per model version
    roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
                     'model_download_count']].drop_duplicates().reset_index(drop=True)

    # add 'custom_score_weights' column to promptBook if not exist
    if 'weighted_score_sum' not in promptBook.columns:
        promptBook.loc[:, 'weighted_score_sum'] = 0

    # merge roster and promptbook
    promptBook = promptBook.merge(roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']],
                                  on=['model_id', 'modelVersion_id'], how='left')

    # add column to record current row index
    promptBook.loc[:, 'row_idx'] = promptBook.index

    return roster, promptBook
39
-
40
-
41
@st.cache_data
def calc_tsne(prompt_id):
    """Compute 2-D t-SNE coordinates for all image features of one prompt.

    Loads every '<digits>.pt' feature file from ../data/feats, runs t-SNE on
    the 'all' feature matrix of `prompt_id`, and returns a DataFrame with
    columns ['x', 'y', 'prompt_id', 'modelVersion_id'].
    """
    print('==> loading feats')
    feats = {}
    # feature files are named '<prompt_id>.pt'
    for pt in os.listdir('../data/feats'):
        if pt.split('.')[-1] == 'pt' and pt.split('.')[0].isdigit():
            feats[pt.split('.')[0]] = torch.load(os.path.join('../data/feats', pt))

    print('==> applying t-SNE')
    # apply t-SNE to entries in each feat in feats to get 2D coordinates
    tsne = TSNE(n_components=2, random_state=0)
    feats[prompt_id]['tsne'] = tsne.fit_transform(feats[prompt_id]['all'].numpy())

    feats_df = pd.DataFrame(feats[prompt_id]['tsne'], columns=['x', 'y'])
    feats_df['prompt_id'] = prompt_id

    # remaining dict keys (besides 'all'/'tsne') identify the model versions;
    # NOTE(review): `k.item()` implies these keys are 0-d tensors — confirm
    keys = []
    for k in feats[prompt_id].keys():
        if k != 'all' and k != 'tsne':
            keys.append(int(k.item()))

    # row order of the feature matrix is assumed to match this key order
    feats_df['modelVersion_id'] = keys

    return feats_df
69
-
70
- # print(feats[prompt_id]['tsne'])
71
-
72
-
73
if __name__ == '__main__':
    st.set_page_config(layout="wide")

    # load dataset
    roster, promptBook = load_hf_dataset()

    with st.sidebar:
        st.write('## Select Prompt')
        prompts = promptBook['prompt_id'].unique().tolist()
        # sort prompts by prompt_id
        prompts.sort()
        prompt_id = st.selectbox('Select Prompt', prompts, index=0)
        physics = st.checkbox('Enable Physics')

    feats_df = calc_tsne(str(prompt_id))

    # build (x, y, image_url) triples for every model version of the prompt
    data = []
    for idx in feats_df.index:
        modelVersion_id = feats_df.loc[idx, 'modelVersion_id']
        image_id = promptBook[(promptBook['modelVersion_id'] == modelVersion_id) & (
                promptBook['prompt_id'] == int(prompt_id))].reset_index(drop=True).loc[0, 'image_id']
        image_url = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{image_id}.png"
        # spread the t-SNE coordinates out so nodes don't overlap
        scale = 50
        data.append((feats_df.loc[idx, 'x'] * scale, feats_df.loc[idx, 'y'] * scale, image_url))

    # NOTE(review): computed from whatever image_id is left over from the last
    # loop iteration, and never used afterwards — dead code
    image_size = promptBook[(promptBook['image_id'] == image_id)].reset_index(drop=True).loc[0, 'size'].split('x')

    nodes = []
    edges = []  # no edges drawn; the graph is used as a scatter plot

    for d in data:
        # the image URL doubles as the node id, so agraph's click return value
        # is directly displayable with st.image below
        nodes.append( Node(id=d[2],
                           size=20,
                           shape="image",
                           image=d[2],
                           # NOTE(review): x/y wrapped in single-element lists here,
                           # while the gallery code passes plain floats — confirm
                           x=[d[0]],
                           y=[d[1]],
                           # nodes are pinned unless physics is enabled
                           fixed=False if physics else True,
                           # NOTE(review): '#00000' is only 5 hex digits — likely
                           # intended to be '#000000'
                           color={'background': '#00000', 'border': '#ffffff'},
                           shadow={'enabled': True, 'color': 'rgba(0,0,0,0.4)', 'size': 10, 'x': 1, 'y': 1},
                           )
                     )

    config = Config(width='100%',
                    height=800,
                    directed=True,
                    physics=physics,
                    hierarchical=False,
                    )

    cols = st.columns([3, 1], gap='large')

    with cols[0]:
        # returns the id (image URL) of the clicked node, or None
        return_value = agraph(nodes=nodes,
                              edges=edges,
                              config=config)

    with cols[1]:
        try:
            st.image(return_value, use_column_width=True)
        except:
            # return_value is None until a node is clicked
            st.write('No image selected')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Archive/bokehTest.py DELETED
@@ -1,182 +0,0 @@
1
- import os
2
-
3
- import streamlit as st
4
- import torch
5
- import pandas as pd
6
- import numpy as np
7
- import requests
8
-
9
- from bokeh.plotting import figure, show
10
- from bokeh.models import HoverTool, ColumnDataSource, CustomJSHover
11
- from bokeh.embed import file_html
12
- from bokeh.resources import CDN # Import CDN here
13
- from datasets import load_dataset, Dataset, load_from_disk
14
- from huggingface_hub import login
15
- from sklearn.manifold import TSNE
16
- from tqdm import tqdm
17
-
18
-
19
- @st.cache_data
20
- def load_hf_dataset():
21
- # login to huggingface
22
- login(token=os.environ.get("HF_TOKEN"))
23
-
24
- # load from huggingface
25
- roster = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Roster', split='train'))
26
- promptBook = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Metadata', split='train'))
27
-
28
- # process dataset
29
- roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
30
- 'model_download_count']].drop_duplicates().reset_index(drop=True)
31
-
32
- # add 'custom_score_weights' column to promptBook if not exist
33
- if 'weighted_score_sum' not in promptBook.columns:
34
- promptBook.loc[:, 'weighted_score_sum'] = 0
35
-
36
- # merge roster and promptbook
37
- promptBook = promptBook.merge(roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']],
38
- on=['model_id', 'modelVersion_id'], how='left')
39
-
40
- # add column to record current row index
41
- promptBook.loc[:, 'row_idx'] = promptBook.index
42
-
43
- return roster, promptBook
44
-
45
- def show_with_bokeh(data, streamlit=False):
46
- # Extract x, y coordinates and image URLs
47
- x_coords, y_coords, image_urls = zip(*data)
48
-
49
- # Create a ColumnDataSource
50
- source = ColumnDataSource(data=dict(x=x_coords, y=y_coords, image=image_urls))
51
-
52
- # Create a figure
53
- p = figure(width=800, height=600)
54
-
55
- # Add scatter plot
56
- scatter = p.scatter(x='x', y='y', size=20, source=source)
57
-
58
- # Define hover tool
59
- hover = HoverTool()
60
- # hover.tooltips = """
61
- # <div>
62
- # <iframe src="@image" width="512" height="512"></iframe>
63
- # </div>
64
- # """
65
- # hover.formatters = {'@image': CustomJSHover(code="""
66
- # const index = cb_data.index;
67
- # const url = cb_data.source.data['image'][index];
68
- # return '<iframe src="' + url + '" width="512" height="512"></iframe>';
69
- # """)}
70
-
71
- hover.tooltips = """
72
- <div>
73
- <img src="@image" style='object-fit: contain'; height=100%">
74
- </div>
75
- """
76
- hover.formatters = {'@image': CustomJSHover(code="""
77
- const index = cb_data.index;
78
- const url = cb_data.source.data['image'][index];
79
- return '<img src="' + url + '">';
80
- """)}
81
-
82
- p.add_tools(hover)
83
-
84
- # Generate HTML with the plot
85
- html = file_html(p, CDN, "Interactive Scatter Plot with Hover Images")
86
-
87
- # Save the HTML file or show it
88
- # with open("scatter_plot_with_hover_images.html", "w") as f:
89
- # f.write(html)
90
-
91
- if streamlit:
92
- st.bokeh_chart(p, use_container_width=True)
93
- else:
94
- show(p)
95
-
96
-
97
- def show_with_bokeh_2(data, image_size=[40, 40], streamlit=False):
98
- # Extract x, y coordinates and image URLs
99
- x_coords, y_coords, image_urls = zip(*data)
100
-
101
- # Create a ColumnDataSource
102
- source = ColumnDataSource(data=dict(x=x_coords, y=y_coords, image=image_urls))
103
-
104
- # Create a figure
105
- p = figure(width=800, height=600, aspect_ratio=1.0)
106
-
107
- # Add image glyphs
108
- # image_size = 40 # Adjust this size as needed
109
- scale = 0.1
110
- image_size = [int(image_size[0])*scale, int(image_size[1])*scale]
111
- print(image_size)
112
- p.image_url(url='image', x='x', y='y', source=source, w=image_size[0], h=image_size[1], anchor="center")
113
-
114
- # Define hover tool
115
- hover = HoverTool()
116
- hover.tooltips = """
117
- <div>
118
- <img src="@image" style='object-fit: contain'; height=100%'">
119
- </div>
120
- """
121
- p.add_tools(hover)
122
-
123
- # Generate HTML with the plot
124
- html = file_html(p, CDN, "Scatter Plot with Images")
125
-
126
- # Save the HTML file or show it
127
- # with open("scatter_plot_with_images.html", "w") as f:
128
- # f.write(html)
129
-
130
- if streamlit:
131
- st.bokeh_chart(p, use_container_width=True)
132
- else:
133
- show(p)
134
-
135
-
136
- if __name__ == '__main__':
137
- # load dataset
138
- roster, promptBook = load_hf_dataset()
139
-
140
- print('==> loading feats')
141
- feats = {}
142
- for pt in os.listdir('../data/feats'):
143
- if pt.split('.')[-1] == 'pt' and pt.split('.')[0].isdigit():
144
- feats[pt.split('.')[0]] = torch.load(os.path.join('../data/feats', pt))
145
-
146
- print('==> applying t-SNE')
147
- # apply t-SNE to entries in each feat in feats to get 2D coordinates
148
- tsne = TSNE(n_components=2, random_state=0)
149
- # for k, v in tqdm(feats.items()):
150
- # feats[k]['tsne'] = tsne.fit_transform(v['all'].numpy())
151
- prompt_id = '49'
152
- feats[prompt_id]['tsne'] = tsne.fit_transform(feats[prompt_id]['all'].numpy())
153
-
154
- print(feats[prompt_id]['tsne'])
155
-
156
- keys = []
157
- for k in feats[prompt_id].keys():
158
- if k != 'all' and k != 'tsne':
159
- keys.append(int(k.item()))
160
-
161
- print(keys)
162
-
163
- data = []
164
- for idx in range(len(keys)):
165
- modelVersion_id = keys[idx]
166
- image_id = promptBook[(promptBook['modelVersion_id'] == modelVersion_id) & (promptBook['prompt_id'] == int(prompt_id))].reset_index(drop=True).loc[0, 'image_id']
167
- image_url = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{image_id}.png"
168
- scale = 50
169
- data.append((feats[prompt_id]['tsne'][idx][0]*scale, feats[prompt_id]['tsne'][idx][1]*scale, image_url))
170
-
171
- image_size = promptBook[(promptBook['image_id'] == image_id)].reset_index(drop=True).loc[0, 'size'].split('x')
172
-
173
- # # Sample data: (x, y) coordinates and corresponding image URLs
174
- # data = [
175
- # (2, 5, "https://www.crunchyroll.com/imgsrv/display/thumbnail/480x720/catalog/crunchyroll/669dae5dbea3d93bb5f1012078501976.jpeg"),
176
- # (4, 8, "https://i.pinimg.com/originals/40/6d/38/406d38957bc4fd12f34c5dfa3d73b86d.jpg"),
177
- # (7, 3, "https://i.pinimg.com/550x/76/27/d2/7627d227adc6fb5fb6662ebfb9d82d7e.jpg"),
178
- # # Add more data points and image URLs
179
- # ]
180
-
181
- # show_with_bokeh(data, streamlit=True)
182
- show_with_bokeh_2(data, image_size=image_size, streamlit=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Archive/optimization.py DELETED
@@ -1,37 +0,0 @@
1
- import numpy as np
2
- from scipy.optimize import minimize, differential_evolution
3
-
4
-
5
- # Define the function y = x_1*w_1 + x_2*w_2 + x_3*w_3
6
- def objective_function(w_indices):
7
- x_1 = x_1_values[int(w_indices[0])]
8
- x_2 = x_2_values[int(w_indices[1])]
9
- x_3 = x_3_values[int(w_indices[2])]
10
- return - (x_1 * w_indices[3] + x_2 * w_indices[4] + x_3 * w_indices[5]) # Use w_indices to get w_1, w_2, w_3
11
-
12
-
13
- if __name__ == '__main__':
14
- # Given sets of discrete values for x_1, x_2, and x_3
15
- x_1_values = [1, 2, 3, 5, 6]
16
- x_2_values = [0, 5, 7, 2, 1]
17
- x_3_values = [3, 7, 4, 5, 2]
18
-
19
- # Perform differential evolution optimization with integer variables
20
- # bounds = [(0, len(x_1_values) - 2), (0, len(x_2_values) - 1), (0, len(x_3_values) - 1), (-1, 1), (-1, 1), (-1, 1)]
21
- bounds = [(3, 4), (3, 4), (3, 4), (-1, 1), (-1, 1), (-1, 1)]
22
- result = differential_evolution(objective_function, bounds)
23
-
24
- # Get the optimal indices of x_1, x_2, and x_3
25
- x_1_index, x_2_index, x_3_index, w_1_opt, w_2_opt, w_3_opt = result.x
26
-
27
- # Calculate the peak point (x_1, x_2, x_3) corresponding to the optimal indices
28
- x_1_peak = x_1_values[int(x_1_index)]
29
- x_2_peak = x_2_values[int(x_2_index)]
30
- x_3_peak = x_3_values[int(x_3_index)]
31
-
32
- # Print the results
33
- print("Optimal w_1:", w_1_opt)
34
- print("Optimal w_2:", w_2_opt)
35
- print("Optimal w_3:", w_3_opt)
36
- print("Peak Point (x_1, x_2, x_3):", (x_1_peak, x_2_peak, x_3_peak))
37
- print("Maximum Value of y:", -result.fun) # Use negative sign as we previously used to maximize
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Archive/optimization2.py DELETED
@@ -1,40 +0,0 @@
1
- import numpy as np
2
- from scipy.optimize import minimize
3
-
4
- if __name__ == '__main__':
5
-
6
- # Given subset of m values for x_1, x_2, and x_3
7
- x1_subset = [2, 3, 4]
8
- x2_subset = [0, 1]
9
- x3_subset = [5, 6, 7]
10
-
11
- # Full set of possible values for x_1, x_2, and x_3
12
- x1_full = [1, 2, 3, 4, 5]
13
- x2_full = [0, 1, 2, 3, 4, 5]
14
- x3_full = [3, 5, 7]
15
-
16
- # Define the objective function for quantile-based ranking
17
- def objective_function(w):
18
- y_subset = [x1 * w[0] + x2 * w[1] + x3 * w[2] for x1, x2, x3 in zip(x1_subset, x2_subset, x3_subset)]
19
- y_full_set = [x1 * w[0] + x2 * w[1] + x3 * w[2] for x1 in x1_full for x2 in x2_full for x3 in x3_full]
20
-
21
- # Calculate the 90th percentile of y values for the full set
22
- y_full_set_90th_percentile = np.percentile(y_full_set, 90)
23
-
24
- # Maximize the difference between the 90th percentile of the subset and the 90th percentile of the full set
25
- return - min(y_subset) + y_full_set_90th_percentile
26
-
27
-
28
- # Bounds for w_1, w_2, and w_3 (-1 to 1)
29
- bounds = [(-1, 1), (-1, 1), (-1, 1)]
30
-
31
- # Perform bounded optimization to find the values of w_1, w_2, and w_3 that maximize the objective function
32
- result = minimize(objective_function, np.zeros(3), method='TNC', bounds=bounds)
33
-
34
- # Get the optimal values of w_1, w_2, and w_3
35
- w_1_opt, w_2_opt, w_3_opt = result.x
36
-
37
- # Print the results
38
- print("Optimal w_1:", w_1_opt)
39
- print("Optimal w_2:", w_2_opt)
40
- print("Optimal w_3:", w_3_opt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/Gallery.py CHANGED
@@ -157,29 +157,29 @@ class GalleryApp:
157
  # save latest threshold
158
  st.session_state.score_weights[3] = nsfw_threshold
159
 
160
- # draw a distribution histogram
161
- if sort_type == 'Scores':
162
- try:
163
- with st.expander('Show score distribution histogram and select score range'):
164
- st.write('**Score distribution histogram**')
165
- chart_space = st.container()
166
- # st.write('Select the range of scores to show')
167
- hist_data = pd.DataFrame(items[sort_by])
168
- mini = hist_data[sort_by].min().item()
169
- mini = mini//0.1 * 0.1
170
- maxi = hist_data[sort_by].max().item()
171
- maxi = maxi//0.1 * 0.1 + 0.1
172
- st.write('**Select the range of scores to show**')
173
- r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
174
- with chart_space:
175
- st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
176
- # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
177
- # r = event_dict.get(sort_by)
178
- if r:
179
- items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
180
- # st.write(r)
181
- except:
182
- pass
183
 
184
  display_options = st.columns([1, 4])
185
 
@@ -208,25 +208,24 @@ class GalleryApp:
208
 
209
  return items, info, col_num
210
 
211
- def sidebar(self):
212
  with st.sidebar:
213
- prompt_tags = self.promptBook['tag'].unique()
214
- # sort tags by alphabetical order
215
- prompt_tags = np.sort(prompt_tags)[::1]
 
 
 
 
 
 
 
 
216
 
217
- tag = st.selectbox('Select a tag', prompt_tags, index=5)
218
 
219
- items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
220
 
221
- prompts = np.sort(items['prompt'].unique())[::1]
222
-
223
- selected_prompt = st.selectbox('Select prompt', prompts, index=3)
224
-
225
- mode = st.radio('Select a mode', ['Gallery', 'Graph'], horizontal=True, index=1)
226
-
227
- items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
228
- prompt_id = items['prompt_id'].unique()[0]
229
- note = items['note'].unique()[0]
230
 
231
  # show source
232
  if isinstance(note, str):
@@ -260,49 +259,66 @@ class GalleryApp:
260
  except:
261
  pass
262
 
263
- return prompt_tags, tag, prompt_id, items, mode
264
 
265
- def app(self):
266
- st.title('Model Visualization and Retrieval')
267
- st.write('This is a gallery of images generated by the models')
268
-
269
- prompt_tags, tag, prompt_id, items, mode = self.sidebar()
270
- # items, info, col_num = self.selection_panel(items)
271
-
272
- # subset = st.radio('Select a subset', ['All', 'Selected Only'], index=0, horizontal=True)
273
- # try:
274
- # if subset == 'Selected Only':
275
- # items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
276
- # except:
277
- # pass
278
-
279
- # add safety check for some prompts
280
- safety_check = True
281
- unsafe_prompts = {}
282
- # initialize unsafe prompts
283
- for prompt_tag in prompt_tags:
284
- unsafe_prompts[prompt_tag] = []
285
- # manually add unsafe prompts
286
- unsafe_prompts['world knowledge'] = [83]
287
- unsafe_prompts['abstract'] = [1, 3]
288
-
289
- if int(prompt_id.item()) in unsafe_prompts[tag]:
290
- st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
291
- safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
292
-
293
- if safety_check:
294
- if mode == 'Gallery':
295
- self.gallery_mode(prompt_id, items)
296
- elif mode == 'Graph':
297
- self.graph_mode(prompt_id, items)
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  def graph_mode(self, prompt_id, items):
301
  graph_cols = st.columns([3, 1])
302
- prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
303
- disabled=False, key=f'{prompt_id}')
304
- if prompt:
305
- switch_page("ranking")
306
 
307
  with graph_cols[0]:
308
  graph_space = st.empty()
@@ -366,20 +382,20 @@ class GalleryApp:
366
  def gallery_mode(self, prompt_id, items):
367
  items, info, col_num = self.selection_panel(items)
368
 
369
- if 'selected_dict' in st.session_state:
370
- # st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
371
- dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
372
- dynamic_weight_panel = st.columns(len(dynamic_weight_options))
373
-
374
- if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
375
- btn_disable = False
376
- else:
377
- btn_disable = True
378
-
379
- for i in range(len(dynamic_weight_options)):
380
- method = dynamic_weight_options[i]
381
- with dynamic_weight_panel[i]:
382
- btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
383
 
384
  prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
385
  if prompt:
@@ -387,32 +403,28 @@ class GalleryApp:
387
 
388
  with st.form(key=f'{prompt_id}'):
389
  # buttons = st.columns([1, 1, 1])
390
- buttons_space = st.columns([1, 1, 1, 1])
391
  gallery_space = st.empty()
392
 
393
  with buttons_space[0]:
394
- continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
395
  if continue_btn:
396
- self.submit_actions('Continue', prompt_id)
 
397
 
398
  with buttons_space[1]:
399
- select_btn = st.form_submit_button('Select All', use_container_width=True)
400
- if select_btn:
401
- self.submit_actions('Select', prompt_id)
402
-
403
- with buttons_space[2]:
404
  deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
405
  if deselect_btn:
406
  self.submit_actions('Deselect', prompt_id)
407
 
408
- with buttons_space[3]:
409
  refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
410
 
411
  with gallery_space.container():
412
  with st.spinner('Loading images...'):
413
  self.gallery_standard(items, col_num, info)
414
 
415
- st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
416
 
417
 
418
 
 
157
  # save latest threshold
158
  st.session_state.score_weights[3] = nsfw_threshold
159
 
160
+ # # draw a distribution histogram
161
+ # if sort_type == 'Scores':
162
+ # try:
163
+ # with st.expander('Show score distribution histogram and select score range'):
164
+ # st.write('**Score distribution histogram**')
165
+ # chart_space = st.container()
166
+ # # st.write('Select the range of scores to show')
167
+ # hist_data = pd.DataFrame(items[sort_by])
168
+ # mini = hist_data[sort_by].min().item()
169
+ # mini = mini//0.1 * 0.1
170
+ # maxi = hist_data[sort_by].max().item()
171
+ # maxi = maxi//0.1 * 0.1 + 0.1
172
+ # st.write('**Select the range of scores to show**')
173
+ # r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
174
+ # with chart_space:
175
+ # st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
176
+ # # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
177
+ # # r = event_dict.get(sort_by)
178
+ # if r:
179
+ # items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
180
+ # # st.write(r)
181
+ # except:
182
+ # pass
183
 
184
  display_options = st.columns([1, 4])
185
 
 
208
 
209
  return items, info, col_num
210
 
211
+ def sidebar(self, items, prompt_id, note):
212
  with st.sidebar:
213
+ # prompt_tags = self.promptBook['tag'].unique()
214
+ # # sort tags by alphabetical order
215
+ # prompt_tags = np.sort(prompt_tags)[::1]
216
+ #
217
+ # tag = st.selectbox('Select a tag', prompt_tags, index=5)
218
+ #
219
+ # items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
220
+ #
221
+ # prompts = np.sort(items['prompt'].unique())[::1]
222
+ #
223
+ # selected_prompt = st.selectbox('Select prompt', prompts, index=3)
224
 
225
+ # mode = st.radio('Select a mode', ['Gallery', 'Graph'], horizontal=True, index=1)
226
 
227
+ # items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
228
 
 
 
 
 
 
 
 
 
 
229
 
230
  # show source
231
  if isinstance(note, str):
 
259
  except:
260
  pass
261
 
262
+ # return prompt_tags, tag, prompt_id, items
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
+ def app(self):
266
+ # st.title('Model Visualization and Retrieval')
267
+ # st.write('This is a gallery of images generated by the models')
268
+
269
+ # build the tabular view
270
+ prompt_tags = self.promptBook['tag'].unique()
271
+ # sort tags by alphabetical order
272
+ prompt_tags = np.sort(prompt_tags)[::1].tolist()
273
+
274
+ tabs = st.tabs(prompt_tags)
275
+ with st.spinner('Loading...'):
276
+ for i in range(len(prompt_tags)):
277
+ with tabs[i]:
278
+ tag = prompt_tags[i]
279
+ items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
280
+
281
+ prompts = np.sort(items['prompt'].unique())[::1]
282
+
283
+ subset_selector = st.columns([3, 1])
284
+ with subset_selector[0]:
285
+ selected_prompt = st.selectbox('Select prompt', prompts, index=3)
286
+ with subset_selector[1]:
287
+ subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{selected_prompt}')
288
+ items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
289
+ prompt_id = items['prompt_id'].unique()[0]
290
+ note = items['note'].unique()[0]
291
+
292
+ # add safety check for some prompts
293
+ safety_check = True
294
+ unsafe_prompts = {}
295
+ # initialize unsafe prompts
296
+ for prompt_tag in prompt_tags:
297
+ unsafe_prompts[prompt_tag] = []
298
+ # manually add unsafe prompts
299
+ unsafe_prompts['world knowledge'] = [83]
300
+ unsafe_prompts['abstract'] = [1, 3]
301
+
302
+ if int(prompt_id.item()) in unsafe_prompts[tag]:
303
+ st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
304
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
305
+
306
+ if safety_check:
307
+
308
+ # if subset == 'Selected Only' and 'selected_dict' in st.session_state:
309
+ # items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
310
+ # self.gallery_mode(prompt_id, items)
311
+ # else:
312
+ self.graph_mode(prompt_id, items)
313
+
314
+ self.sidebar(items, prompt_id, note)
315
 
316
  def graph_mode(self, prompt_id, items):
317
  graph_cols = st.columns([3, 1])
318
+ # prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
319
+ # disabled=False, key=f'{prompt_id}')
320
+ # if prompt:
321
+ # switch_page("ranking")
322
 
323
  with graph_cols[0]:
324
  graph_space = st.empty()
 
382
  def gallery_mode(self, prompt_id, items):
383
  items, info, col_num = self.selection_panel(items)
384
 
385
+ # if 'selected_dict' in st.session_state:
386
+ # # st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
387
+ # dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
388
+ # dynamic_weight_panel = st.columns(len(dynamic_weight_options))
389
+ #
390
+ # if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
391
+ # btn_disable = False
392
+ # else:
393
+ # btn_disable = True
394
+ #
395
+ # for i in range(len(dynamic_weight_options)):
396
+ # method = dynamic_weight_options[i]
397
+ # with dynamic_weight_panel[i]:
398
+ # btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
399
 
400
  prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
401
  if prompt:
 
403
 
404
  with st.form(key=f'{prompt_id}'):
405
  # buttons = st.columns([1, 1, 1])
406
+ buttons_space = st.columns([1, 1, 1])
407
  gallery_space = st.empty()
408
 
409
  with buttons_space[0]:
410
+ continue_btn = st.form_submit_button('Proceed selections to ranking', use_container_width=True, type='primary')
411
  if continue_btn:
412
+ # self.submit_actions('Continue', prompt_id)
413
+ switch_page("ranking")
414
 
415
  with buttons_space[1]:
 
 
 
 
 
416
  deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
417
  if deselect_btn:
418
  self.submit_actions('Deselect', prompt_id)
419
 
420
+ with buttons_space[2]:
421
  refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
422
 
423
  with gallery_space.container():
424
  with st.spinner('Loading images...'):
425
  self.gallery_standard(items, col_num, info)
426
 
427
+ # st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
428
 
429
 
430