Spaces:
Sleeping
Sleeping
add set top methods
Browse files- pages/Gallery.py +24 -24
pages/Gallery.py
CHANGED
@@ -272,21 +272,20 @@ class GalleryApp:
|
|
272 |
if safety_check:
|
273 |
items, info, col_num, preprocessor = self.selection_panel(items)
|
274 |
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
# st.write(st.session_state.selected_dict.get(prompt_id, []))
|
290 |
|
291 |
with st.form(key=f'{prompt_id}'):
|
292 |
# buttons = st.columns([1, 1, 1])
|
@@ -315,9 +314,6 @@ class GalleryApp:
|
|
315 |
with st.spinner('Loading images...'):
|
316 |
self.gallery_standard(items, col_num, info)
|
317 |
|
318 |
-
with st.sidebar:
|
319 |
-
st.write(str(st.session_state.selected_dict[prompt_id]))
|
320 |
-
|
321 |
def submit_actions(self, status, prompt_id):
|
322 |
if status == 'Select':
|
323 |
modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
|
@@ -329,7 +325,6 @@ class GalleryApp:
|
|
329 |
print(st.session_state.selected_dict, 'deselect')
|
330 |
st.experimental_rerun()
|
331 |
# self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
|
332 |
-
pass
|
333 |
elif status == 'Continue':
|
334 |
st.session_state.selected_dict[prompt_id] = []
|
335 |
for key in st.session_state:
|
@@ -339,10 +334,13 @@ class GalleryApp:
|
|
339 |
st.session_state.selected_dict[prompt_id].append(int(keys[2]))
|
340 |
# switch_page("ranking")
|
341 |
print(st.session_state.selected_dict, 'continue')
|
342 |
-
|
343 |
|
344 |
-
def dynamic_weight(self,
|
|
|
|
|
345 |
optimal_weight = [0, 0, 0]
|
|
|
346 |
if method == 'Grid Search':
|
347 |
# grid search method
|
348 |
top_ranking = len(items) * len(selected)
|
@@ -350,9 +348,10 @@ class GalleryApp:
|
|
350 |
for clip_weight in np.arange(-1, 1, 0.1):
|
351 |
for mcos_weight in np.arange(-1, 1, 0.1):
|
352 |
for pop_weight in np.arange(-1, 1, 0.1):
|
353 |
-
weight_all = clip_weight*items[f'norm_clip_{preprocessor}'] + mcos_weight*items[f'norm_mcos_{preprocessor}'] + pop_weight*items['norm_pop']
|
354 |
-
weight_all_sorted = weight_all.sort_values(ascending=False)
|
355 |
|
|
|
|
|
|
|
356 |
weight_selected = clip_weight*selected[f'norm_clip_{preprocessor}'] + mcos_weight*selected[f'norm_mcos_{preprocessor}'] + pop_weight*selected['norm_pop']
|
357 |
|
358 |
# get the index of values of weight_selected in weight_all_sorted
|
@@ -361,6 +360,7 @@ class GalleryApp:
|
|
361 |
rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
|
362 |
if sum(rankings) <= top_ranking:
|
363 |
top_ranking = sum(rankings)
|
|
|
364 |
optimal_weight = [clip_weight, mcos_weight, pop_weight]
|
365 |
print('optimal weight:', optimal_weight)
|
366 |
|
@@ -401,7 +401,7 @@ class GalleryApp:
|
|
401 |
optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
|
402 |
print('optimal weight:', optimal_weight)
|
403 |
|
404 |
-
|
405 |
|
406 |
|
407 |
|
|
|
272 |
if safety_check:
|
273 |
items, info, col_num, preprocessor = self.selection_panel(items)
|
274 |
|
275 |
+
if 'selected_dict' in st.session_state:
|
276 |
+
st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
|
277 |
+
dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
|
278 |
+
dynamic_weight_panel = st.columns(len(dynamic_weight_options))
|
279 |
+
|
280 |
+
if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
|
281 |
+
btn_disable = False
|
282 |
+
else:
|
283 |
+
btn_disable = True
|
284 |
+
|
285 |
+
for i in range(len(dynamic_weight_options)):
|
286 |
+
method = dynamic_weight_options[i]
|
287 |
+
with dynamic_weight_panel[i]:
|
288 |
+
btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, preprocessor, method))
|
|
|
289 |
|
290 |
with st.form(key=f'{prompt_id}'):
|
291 |
# buttons = st.columns([1, 1, 1])
|
|
|
314 |
with st.spinner('Loading images...'):
|
315 |
self.gallery_standard(items, col_num, info)
|
316 |
|
|
|
|
|
|
|
317 |
def submit_actions(self, status, prompt_id):
|
318 |
if status == 'Select':
|
319 |
modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
|
|
|
325 |
print(st.session_state.selected_dict, 'deselect')
|
326 |
st.experimental_rerun()
|
327 |
# self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
|
|
|
328 |
elif status == 'Continue':
|
329 |
st.session_state.selected_dict[prompt_id] = []
|
330 |
for key in st.session_state:
|
|
|
334 |
st.session_state.selected_dict[prompt_id].append(int(keys[2]))
|
335 |
# switch_page("ranking")
|
336 |
print(st.session_state.selected_dict, 'continue')
|
337 |
+
st.experimental_rerun()
|
338 |
|
339 |
+
def dynamic_weight(self, prompt_id, items, preprocessor='crop', method='Grid Search'):
|
340 |
+
selected = items[
|
341 |
+
items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
|
342 |
optimal_weight = [0, 0, 0]
|
343 |
+
|
344 |
if method == 'Grid Search':
|
345 |
# grid search method
|
346 |
top_ranking = len(items) * len(selected)
|
|
|
348 |
for clip_weight in np.arange(-1, 1, 0.1):
|
349 |
for mcos_weight in np.arange(-1, 1, 0.1):
|
350 |
for pop_weight in np.arange(-1, 1, 0.1):
|
|
|
|
|
351 |
|
352 |
+
weight_all = clip_weight*items[f'norm_clip_{preprocessor}'] + mcos_weight*items[f'norm_mcos_{preprocessor}'] + pop_weight*items['norm_pop']
|
353 |
+
weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
|
354 |
+
# print('weight_all_sorted:', weight_all_sorted)
|
355 |
weight_selected = clip_weight*selected[f'norm_clip_{preprocessor}'] + mcos_weight*selected[f'norm_mcos_{preprocessor}'] + pop_weight*selected['norm_pop']
|
356 |
|
357 |
# get the index of values of weight_selected in weight_all_sorted
|
|
|
360 |
rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
|
361 |
if sum(rankings) <= top_ranking:
|
362 |
top_ranking = sum(rankings)
|
363 |
+
print('current top ranking:', top_ranking, rankings)
|
364 |
optimal_weight = [clip_weight, mcos_weight, pop_weight]
|
365 |
print('optimal weight:', optimal_weight)
|
366 |
|
|
|
401 |
optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
|
402 |
print('optimal weight:', optimal_weight)
|
403 |
|
404 |
+
st.session_state.score_weights[0: 3] = optimal_weight
|
405 |
|
406 |
|
407 |
|