Ricercar commited on
Commit
97b4d0f
1 Parent(s): 49e2601

important bug fix for image selection

Browse files
Archive/optimization.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.optimize import minimize, differential_evolution
3
+
4
+
5
+ # Define the function y = x_1*w_1 + x_2*w_2 + x_3*w_3
6
+ def objective_function(w_indices):
7
+ x_1 = x_1_values[int(w_indices[0])]
8
+ x_2 = x_2_values[int(w_indices[1])]
9
+ x_3 = x_3_values[int(w_indices[2])]
10
+ return - (x_1 * w_indices[3] + x_2 * w_indices[4] + x_3 * w_indices[5]) # Use w_indices to get w_1, w_2, w_3
11
+
12
+
13
+ if __name__ == '__main__':
14
+ # Given sets of discrete values for x_1, x_2, and x_3
15
+ x_1_values = [1, 2, 3, 5, 6]
16
+ x_2_values = [0, 5, 7, 2, 1]
17
+ x_3_values = [3, 7, 4, 5, 2]
18
+
19
+ # Perform differential evolution optimization with integer variables
20
+ # bounds = [(0, len(x_1_values) - 2), (0, len(x_2_values) - 1), (0, len(x_3_values) - 1), (-1, 1), (-1, 1), (-1, 1)]
21
+ bounds = [(3, 4), (3, 4), (3, 4), (-1, 1), (-1, 1), (-1, 1)]
22
+ result = differential_evolution(objective_function, bounds)
23
+
24
+ # Get the optimal indices of x_1, x_2, and x_3
25
+ x_1_index, x_2_index, x_3_index, w_1_opt, w_2_opt, w_3_opt = result.x
26
+
27
+ # Calculate the peak point (x_1, x_2, x_3) corresponding to the optimal indices
28
+ x_1_peak = x_1_values[int(x_1_index)]
29
+ x_2_peak = x_2_values[int(x_2_index)]
30
+ x_3_peak = x_3_values[int(x_3_index)]
31
+
32
+ # Print the results
33
+ print("Optimal w_1:", w_1_opt)
34
+ print("Optimal w_2:", w_2_opt)
35
+ print("Optimal w_3:", w_3_opt)
36
+ print("Peak Point (x_1, x_2, x_3):", (x_1_peak, x_2_peak, x_3_peak))
37
+ print("Maximum Value of y:", -result.fun) # Use negative sign as we previously used to maximize
Archive/optimization2.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.optimize import minimize
3
+
4
+ if __name__ == '__main__':
5
+
6
+ # Given subset of m values for x_1, x_2, and x_3
7
+ x1_subset = [2, 3, 4]
8
+ x2_subset = [0, 1]
9
+ x3_subset = [5, 6, 7]
10
+
11
+ # Full set of possible values for x_1, x_2, and x_3
12
+ x1_full = [1, 2, 3, 4, 5]
13
+ x2_full = [0, 1, 2, 3, 4, 5]
14
+ x3_full = [3, 5, 7]
15
+
16
+ # Define the objective function for quantile-based ranking
17
+ def objective_function(w):
18
+ y_subset = [x1 * w[0] + x2 * w[1] + x3 * w[2] for x1, x2, x3 in zip(x1_subset, x2_subset, x3_subset)]
19
+ y_full_set = [x1 * w[0] + x2 * w[1] + x3 * w[2] for x1 in x1_full for x2 in x2_full for x3 in x3_full]
20
+
21
+ # Calculate the 90th percentile of y values for the full set
22
+ y_full_set_90th_percentile = np.percentile(y_full_set, 90)
23
+
24
+ # Maximize the difference between the 90th percentile of the subset and the 90th percentile of the full set
25
+ return - min(y_subset) + y_full_set_90th_percentile
26
+
27
+
28
+ # Bounds for w_1, w_2, and w_3 (-1 to 1)
29
+ bounds = [(-1, 1), (-1, 1), (-1, 1)]
30
+
31
+ # Perform bounded optimization to find the values of w_1, w_2, and w_3 that maximize the objective function
32
+ result = minimize(objective_function, np.zeros(3), method='TNC', bounds=bounds)
33
+
34
+ # Get the optimal values of w_1, w_2, and w_3
35
+ w_1_opt, w_2_opt, w_3_opt = result.x
36
+
37
+ # Print the results
38
+ print("Optimal w_1:", w_1_opt)
39
+ print("Optimal w_2:", w_2_opt)
40
+ print("Optimal w_3:", w_3_opt)
Archive/test_form.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ def grid(col=3, row=4, name='grid1'):
5
+ cols = st.columns(col)
6
+ for i in range(row):
7
+ for j in range(col):
8
+ with cols[j]:
9
+ value = st.session_state.checked_dic[name].get(f"{name}_{i*col+j}", False)
10
+
11
+ check = st.checkbox(f"{i*col+j}", key=f"{name}_{i*col+j}", value=value)
12
+ if check:
13
+ st.session_state.checked_dic[name][f"{name}_{i*col+j}"] = True
14
+ else:
15
+ st.session_state.checked_dic[name][f"{name}_{i*col+j}"] = False
16
+
17
+
18
+ def on_click():
19
+ for key in st.session_state:
20
+ if st.session_state[key] and key[-1].isdigit():
21
+ st.write(key)
22
+ # for key in st.session_state.checked_dic[name]:
23
+ # if st.session_state.checked_dic[name][key]:
24
+ # st.write(key)
25
+
26
+
27
+
28
+ if __name__ == "__main__":
29
+ if 'checked_dic' not in st.session_state:
30
+ st.session_state.checked_dic = {'grid1': {}, 'grid2': {}}
31
+
32
+ name = st.selectbox('Select a grid', ['grid1', 'grid2'])
33
+
34
+ with st.form(f"{name}_form"):
35
+ grid(name=name)
36
+ submit_button = st.form_submit_button("Submit", on_click=on_click)
37
+
38
+
39
+
Home.py CHANGED
@@ -38,6 +38,8 @@ def logout():
38
 
39
 
40
  if __name__ == '__main__':
 
 
41
  st.set_page_config(page_title="Login", page_icon="🏠", layout="wide")
42
  st.write('A Research by MAPS Lab, NYU Shanghai')
43
  st.title("Personalized Model Coffer")
 
38
 
39
 
40
  if __name__ == '__main__':
41
+ # print(st.source_util.get_pages('Home.py'))
42
+
43
  st.set_page_config(page_title="Login", page_icon="🏠", layout="wide")
44
  st.write('A Research by MAPS Lab, NYU Shanghai')
45
  st.title("Personalized Model Coffer")
pages/Gallery.py CHANGED
@@ -1,14 +1,16 @@
1
- import streamlit as st
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
- import glob
 
 
5
  from datasets import load_dataset, Dataset, load_from_disk
6
  from huggingface_hub import login
7
- import os
8
- import requests
9
- from bs4 import BeautifulSoup
10
- import altair as alt
11
  from streamlit_extras.switch_page_button import switch_page
 
12
 
13
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
14
 
@@ -62,20 +64,25 @@ class GalleryApp:
62
  # handel checkbox information
63
  prompt_id = items.iloc[idx + j]['prompt_id']
64
  modelVersion_id = items.iloc[idx + j]['modelVersion_id']
 
65
  check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
66
 
 
 
67
  # show checkbox
68
- checked = st.checkbox('Select', key=f'select_{idx + j}', value=check_init)
69
- if checked:
70
- if prompt_id not in st.session_state.selected_dict:
71
- st.session_state.selected_dict[prompt_id] = []
72
- if modelVersion_id not in st.session_state.selected_dict[prompt_id]:
73
- st.session_state.selected_dict[prompt_id].append(modelVersion_id)
74
- else:
75
- try:
76
- st.session_state.selected_dict[prompt_id].remove(modelVersion_id)
77
- except:
78
- pass
 
 
79
 
80
  # show selected info
81
  for key in info:
@@ -186,7 +193,7 @@ class GalleryApp:
186
  # select number of columns
187
  col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
188
 
189
- return items, info, col_num
190
 
191
  def sidebar(self):
192
  with st.sidebar:
@@ -244,7 +251,7 @@ class GalleryApp:
244
  st.title('Model Visualization and Retrieval')
245
  st.write('This is a gallery of images generated by the models')
246
 
247
- prompt_tags, tag, prompt_id, items = self.sidebar()
248
 
249
  # add safety check for some prompts
250
  safety_check = True
@@ -263,8 +270,23 @@ class GalleryApp:
263
  safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
264
 
265
  if safety_check:
266
- items, info, col_num = self.selection_panel(items)
267
- # self.gallery_standard(items, col_num, info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  with st.form(key=f'{prompt_id}'):
270
  # buttons = st.columns([1, 1, 1])
@@ -293,20 +315,97 @@ class GalleryApp:
293
  with st.spinner('Loading images...'):
294
  self.gallery_standard(items, col_num, info)
295
 
 
 
 
296
  def submit_actions(self, status, prompt_id):
297
  if status == 'Select':
298
  modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
299
  st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
300
  print(st.session_state.selected_dict, 'select')
 
301
  elif status == 'Deselect':
302
  st.session_state.selected_dict[prompt_id] = []
303
  print(st.session_state.selected_dict, 'deselect')
 
304
  # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
305
  pass
306
  elif status == 'Continue':
 
 
 
 
 
 
307
  # switch_page("ranking")
 
308
  pass
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
  @st.cache_data
312
  def load_hf_dataset():
@@ -342,6 +441,7 @@ def load_hf_dataset():
342
 
343
  if __name__ == "__main__":
344
  st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
 
345
  if 'user_id' not in st.session_state:
346
  st.warning('Please log in first.')
347
  home_btn = st.button('Go to Home Page')
 
1
+ import os
2
+ import requests
3
+
4
+ import altair as alt
5
  import numpy as np
6
  import pandas as pd
7
+ import streamlit as st
8
+
9
+ from bs4 import BeautifulSoup
10
  from datasets import load_dataset, Dataset, load_from_disk
11
  from huggingface_hub import login
 
 
 
 
12
  from streamlit_extras.switch_page_button import switch_page
13
+ from sklearn.svm import LinearSVC
14
 
15
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
16
 
 
64
  # handel checkbox information
65
  prompt_id = items.iloc[idx + j]['prompt_id']
66
  modelVersion_id = items.iloc[idx + j]['modelVersion_id']
67
+
68
  check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
69
 
70
+ st.write("Position: ", idx + j)
71
+
72
  # show checkbox
73
+ checked = st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
74
+
75
+ #
76
+ # if checked:
77
+ # if prompt_id not in st.session_state.selected_dict:
78
+ # st.session_state.selected_dict[prompt_id] = []
79
+ # if modelVersion_id not in st.session_state.selected_dict[prompt_id]:
80
+ # st.session_state.selected_dict[prompt_id].append(modelVersion_id)
81
+ # else:
82
+ # try:
83
+ # st.session_state.selected_dict[prompt_id].remove(modelVersion_id)
84
+ # except:
85
+ # pass
86
 
87
  # show selected info
88
  for key in info:
 
193
  # select number of columns
194
  col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
195
 
196
+ return items, info, col_num, preprocessor
197
 
198
  def sidebar(self):
199
  with st.sidebar:
 
251
  st.title('Model Visualization and Retrieval')
252
  st.write('This is a gallery of images generated by the models')
253
 
254
+ prompt_tags, tag, prompt_id, items= self.sidebar()
255
 
256
  # add safety check for some prompts
257
  safety_check = True
 
270
  safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
271
 
272
  if safety_check:
273
+ items, info, col_num, preprocessor = self.selection_panel(items)
274
+
275
+ # method = st.radio('Select a method to set dynamic weight', ['Grid Search', 'SVM', 'Greedy', 'Disable dynamic weight'], index=0, horizontal=True)
276
+ #
277
+ # if method != 'Disable dynamic weight':
278
+ # if len(st.session_state.selected_dict[prompt_id]) > 0:
279
+ # selected = items[
280
+ # items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
281
+ # drop=True)
282
+ # st.session_state.score_weights[0: 3] = self.dynamic_weight(selected, items, preprocessor,
283
+ # method=method)
284
+ # # st.experimental_rerun()
285
+ #
286
+ # else:
287
+ # print('no selected models')
288
+ #
289
+ # st.write(st.session_state.selected_dict.get(prompt_id, []))
290
 
291
  with st.form(key=f'{prompt_id}'):
292
  # buttons = st.columns([1, 1, 1])
 
315
  with st.spinner('Loading images...'):
316
  self.gallery_standard(items, col_num, info)
317
 
318
+ with st.sidebar:
319
+ st.write(str(st.session_state.selected_dict[prompt_id]))
320
+
321
  def submit_actions(self, status, prompt_id):
322
  if status == 'Select':
323
  modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
324
  st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
325
  print(st.session_state.selected_dict, 'select')
326
+ st.experimental_rerun()
327
  elif status == 'Deselect':
328
  st.session_state.selected_dict[prompt_id] = []
329
  print(st.session_state.selected_dict, 'deselect')
330
+ st.experimental_rerun()
331
  # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
332
  pass
333
  elif status == 'Continue':
334
+ st.session_state.selected_dict[prompt_id] = []
335
+ for key in st.session_state:
336
+ keys = key.split('_')
337
+ if keys[0] == 'select' and keys[1] == str(prompt_id):
338
+ if st.session_state[key]:
339
+ st.session_state.selected_dict[prompt_id].append(int(keys[2]))
340
  # switch_page("ranking")
341
+ print(st.session_state.selected_dict, 'continue')
342
  pass
343
 
344
+ def dynamic_weight(self, selected, items, preprocessor='crop', method='Grid Search'):
345
+ optimal_weight = [0, 0, 0]
346
+ if method == 'Grid Search':
347
+ # grid search method
348
+ top_ranking = len(items) * len(selected)
349
+
350
+ for clip_weight in np.arange(-1, 1, 0.1):
351
+ for mcos_weight in np.arange(-1, 1, 0.1):
352
+ for pop_weight in np.arange(-1, 1, 0.1):
353
+ weight_all = clip_weight*items[f'norm_clip_{preprocessor}'] + mcos_weight*items[f'norm_mcos_{preprocessor}'] + pop_weight*items['norm_pop']
354
+ weight_all_sorted = weight_all.sort_values(ascending=False)
355
+
356
+ weight_selected = clip_weight*selected[f'norm_clip_{preprocessor}'] + mcos_weight*selected[f'norm_mcos_{preprocessor}'] + pop_weight*selected['norm_pop']
357
+
358
+ # get the index of values of weight_selected in weight_all_sorted
359
+ rankings = []
360
+ for weight in weight_selected:
361
+ rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
362
+ if sum(rankings) <= top_ranking:
363
+ top_ranking = sum(rankings)
364
+ optimal_weight = [clip_weight, mcos_weight, pop_weight]
365
+ print('optimal weight:', optimal_weight)
366
+
367
+ elif method == 'SVM':
368
+ # svm method
369
+ print('start svm method')
370
+ # get residual dataframe that contains models not selected
371
+ residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
372
+ residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
373
+ residual = residual.to_numpy()
374
+ selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
375
+ selected = selected.to_numpy()
376
+
377
+ y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
378
+ X = np.concatenate((selected, residual), axis=0)
379
+
380
+ # fit svm model, and get parameters for the hyperplane
381
+ clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
382
+ clf.fit(X, y)
383
+ optimal_weight = clf.coef_[0].tolist()
384
+ print('optimal weight:', optimal_weight)
385
+ pass
386
+
387
+ elif method == 'Greedy':
388
+ for idx in selected.index:
389
+ # find which score is the highest, clip, mcos, or pop
390
+ clip_score = selected.loc[idx, 'norm_clip_crop']
391
+ mcos_score = selected.loc[idx, 'norm_mcos_crop']
392
+ pop_score = selected.loc[idx, 'norm_pop']
393
+ if clip_score >= mcos_score and clip_score >= pop_score:
394
+ optimal_weight[0] += 1
395
+ elif mcos_score >= clip_score and mcos_score >= pop_score:
396
+ optimal_weight[1] += 1
397
+ elif pop_score >= clip_score and pop_score >= mcos_score:
398
+ optimal_weight[2] += 1
399
+
400
+ # normalize optimal_weight
401
+ optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
402
+ print('optimal weight:', optimal_weight)
403
+
404
+ return optimal_weight
405
+
406
+
407
+
408
+
409
 
410
  @st.cache_data
411
  def load_hf_dataset():
 
441
 
442
  if __name__ == "__main__":
443
  st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
444
+
445
  if 'user_id' not in st.session_state:
446
  st.warning('Please log in first.')
447
  home_btn = st.button('Go to Home Page')
pages/__pycache__/Gallery.cpython-39.pyc CHANGED
Binary files a/pages/__pycache__/Gallery.cpython-39.pyc and b/pages/__pycache__/Gallery.cpython-39.pyc differ