Spaces:
Sleeping
Sleeping
change summary visual
Browse files- Home.py +9 -9
- pages/Summary.py +69 -52
Home.py
CHANGED
@@ -55,11 +55,11 @@ def info():
|
|
55 |
with st.sidebar:
|
56 |
st.write('## About')
|
57 |
st.write(
|
58 |
-
"
|
59 |
)
|
60 |
|
61 |
st.write(
|
62 |
-
"After picking images you liked from Gallery and a
|
63 |
)
|
64 |
|
65 |
|
@@ -88,15 +88,15 @@ if __name__ == '__main__':
|
|
88 |
st.write('### About GEMRec')
|
89 |
st.write("**GE**nerative **M**odel **Rec**ommendation (**GEMRec**) is a research project by [MAPS Lab](https://github.com/MAPS-research), NYU Shanghai.")
|
90 |
st.write('### Our Task')
|
91 |
-
st.write('Given a userβs preference
|
92 |
st.write('### Our Approach')
|
93 |
-
st.write('We propose a two-stage framework, which contains prompt-model
|
94 |
-
st.write('### Key Contributions')
|
95 |
-
st.write('1. We propose a two-stage framework to approach the Generative Model Recommendation problem. Our framework allows end-users to effectively explore a diverse set of generative models to understand their expressiveness. It also allows system developers to elicit user preferences for items generated from personalized prompts.')
|
96 |
-
st.write('2. We release GEMRec-18K, a dense prompt-model interaction dataset that consists of 18K images generated by pairing 200 generative models with 90 prompts collected from real-world usages, accompanied by detailed metadata and generation configurations. This dataset builds the cornerstone for exploring Generative Recommendation and can be useful for other tasks related to understanding generative models')
|
97 |
-
st.write('3. We take the first step in examining evaluation metrics for personalized image generations and identify several limitations in existing metrics. We propose a weighted metric that is more suitable for the task and opens up directions for future improvements in model training and evaluations.')
|
98 |
|
99 |
-
with st.expander(label='
|
100 |
st.write('### Paper')
|
101 |
st.write('Arxiv: [Towards Personalized Prompt-Model Retrieval for Generative Recommendation](https://arxiv.org/abs/2308.02205)')
|
102 |
st.write('### GEMRec-18K Dataset')
|
|
|
55 |
with st.sidebar:
|
56 |
st.write('## About')
|
57 |
st.write(
|
58 |
+
"This is a web application **for individual users to quickly dig out the most preferable text-to-image models from [civitai](https://civitai.com) for different prompts**. Our research aims to understand personal preference towards generative models and you can contribute by playing with this tool and giving us your feedback! "
|
59 |
)
|
60 |
|
61 |
st.write(
|
62 |
+
"After picking images you liked from Gallery and a Ranking Contest, a summary dashboard will be presented **indicating your preferred models with download links ready to be deployed in [Webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)** !"
|
63 |
)
|
64 |
|
65 |
|
|
|
88 |
st.write('### About GEMRec')
|
89 |
st.write("**GE**nerative **M**odel **Rec**ommendation (**GEMRec**) is a research project by [MAPS Lab](https://github.com/MAPS-research), NYU Shanghai.")
|
90 |
st.write('### Our Task')
|
91 |
+
st.write('Navigate hundreds of text-to-image models through various categories of pre-defined prompts and a graph-based interface. Given a userβs preference and interaction data, we aim to recommend the most preferred generative model for the user.')
|
92 |
st.write('### Our Approach')
|
93 |
+
st.write('We propose a two-stage framework, which contains prompt-model retrieval and generative model ranking. :red[Your participation in this web application will help us to improve our framework and to further our research on personalization.]')
|
94 |
+
# st.write('### Key Contributions')
|
95 |
+
# st.write('1. We propose a two-stage framework to approach the Generative Model Recommendation problem. Our framework allows end-users to effectively explore a diverse set of generative models to understand their expressiveness. It also allows system developers to elicit user preferences for items generated from personalized prompts.')
|
96 |
+
# st.write('2. We release GEMRec-18K, a dense prompt-model interaction dataset that consists of 18K images generated by pairing 200 generative models with 90 prompts collected from real-world usages, accompanied by detailed metadata and generation configurations. This dataset builds the cornerstone for exploring Generative Recommendation and can be useful for other tasks related to understanding generative models')
|
97 |
+
# st.write('3. We take the first step in examining evaluation metrics for personalized image generations and identify several limitations in existing metrics. We propose a weighted metric that is more suitable for the task and opens up directions for future improvements in model training and evaluations.')
|
98 |
|
99 |
+
with st.expander(label='**π Where can I find the paper and dataset?**'):
|
100 |
st.write('### Paper')
|
101 |
st.write('Arxiv: [Towards Personalized Prompt-Model Retrieval for Generative Recommendation](https://arxiv.org/abs/2308.02205)')
|
102 |
st.write('### GEMRec-18K Dataset')
|
pages/Summary.py
CHANGED
@@ -30,7 +30,7 @@ class DashboardApp:
|
|
30 |
|
31 |
def sidebar(self, tags, mode):
|
32 |
with st.sidebar:
|
33 |
-
tag = st.selectbox('Select a tag', tags, key='tag')
|
34 |
# st.write('---')
|
35 |
with st.form('summary_sidebar_form'):
|
36 |
st.write('## Want a more comprehensive summary?')
|
@@ -48,10 +48,10 @@ class DashboardApp:
|
|
48 |
# if submit_feedback:
|
49 |
# print(feedback)
|
50 |
|
51 |
-
return tag
|
52 |
|
53 |
def leaderboard(self, tag, db_table):
|
54 |
-
tag = '%' if tag == '
|
55 |
|
56 |
# get the ranking results of the current user
|
57 |
curser = RANKING_CONN.cursor()
|
@@ -65,39 +65,40 @@ class DashboardApp:
|
|
65 |
# sort the modelVersion_standings by value into a list of tuples in descending order
|
66 |
st.session_state.modelVersion_standings[tag] = sorted(st.session_state.modelVersion_standings[tag].items(), key=lambda x: x[1], reverse=True)
|
67 |
|
68 |
-
tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
|
69 |
|
70 |
-
with tab1:
|
71 |
# self.podium(modelVersion_standings)
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
with tab2:
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
101 |
|
102 |
def podium_expander(self, tag, example_prompts, n=3, summary_mode: ['display', 'edit'] = 'display'):
|
103 |
|
@@ -110,29 +111,40 @@ class DashboardApp:
|
|
110 |
icon = 'π₯'if i == 0 else 'π₯' if i == 1 else 'π₯' if i == 2 else 'π'
|
111 |
podium_display = st.columns([1, 14], gap='medium')
|
112 |
with podium_display[0]:
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
119 |
with podium_display[1]:
|
120 |
-
title_display = st.columns([3,
|
121 |
with title_display[0]:
|
122 |
st.write(f'##### {model_name}, {modelVersion_name}')
|
123 |
# st.write(f'Ranking Score: {winning_times}')
|
124 |
with title_display[1]:
|
125 |
# image_display = st.selectbox('image display', ['Featured', 'All Images'], key=f'image_display_{modelVersion_id}', label_visibility='collapsed')
|
126 |
-
image_display = st.
|
127 |
|
128 |
with title_display[2]:
|
129 |
-
st.link_button('Download
|
130 |
with title_display[3]:
|
131 |
-
st.link_button('Civitai
|
132 |
# st.write(f'[Civitai Page](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}), [Model Download Link]({url}), Ranking Score: {winning_times}')
|
133 |
# with st.expander(f'**{icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})**, Ranking Score: {winning_times}'):
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
if not image_display:
|
138 |
example_images = self.promptBook[self.promptBook['prompt_id'].isin(example_prompts) & (self.promptBook['modelVersion_id']==modelVersion_id)]['image_id'].values
|
@@ -143,11 +155,10 @@ class DashboardApp:
|
|
143 |
)
|
144 |
|
145 |
else:
|
146 |
-
st.toast('π It may take a while to load all images. Please be patient.')
|
147 |
# with st.expander(f'Show Images'):
|
148 |
images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
|
149 |
|
150 |
-
safety_check = st.
|
151 |
unsafe_prompts = json.load(open('data/unsafe_prompts.json', 'r'))
|
152 |
# merge dict values into one list
|
153 |
unsafe_prompts = [item for sublist in unsafe_prompts.values() for item in sublist]
|
@@ -162,6 +173,7 @@ class DashboardApp:
|
|
162 |
images,
|
163 |
img_style={"margin": "5px", "height": "100px"}
|
164 |
)
|
|
|
165 |
|
166 |
# # st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
|
167 |
# col_num = 4
|
@@ -212,7 +224,7 @@ class DashboardApp:
|
|
212 |
# get tags from database of the current user
|
213 |
db_table = 'sort_results' if mode == 'Drag and Sort' else 'battle_results'
|
214 |
|
215 |
-
tags = ['
|
216 |
curser = RANKING_CONN.cursor()
|
217 |
curser.execute(
|
218 |
f"SELECT DISTINCT tag FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}'")
|
@@ -220,11 +232,13 @@ class DashboardApp:
|
|
220 |
tags.append(row['tag'])
|
221 |
curser.close()
|
222 |
|
223 |
-
if tags == ['
|
224 |
st.info(f'No rankings are finished with {mode} mode yet.')
|
225 |
|
226 |
else:
|
227 |
-
|
|
|
|
|
228 |
self.leaderboard(tag, db_table)
|
229 |
|
230 |
with st.sidebar:
|
@@ -240,7 +254,7 @@ class DashboardApp:
|
|
240 |
RANKING_CONN.commit()
|
241 |
curser.close()
|
242 |
|
243 |
-
st.toast('Thanks for your feedback! We will take it into consideration in our future work
|
244 |
|
245 |
|
246 |
if __name__ == "__main__":
|
@@ -280,4 +294,7 @@ if __name__ == "__main__":
|
|
280 |
app = DashboardApp(roster, promptBook, session_finished)
|
281 |
app.app()
|
282 |
|
|
|
|
|
|
|
283 |
|
|
|
30 |
|
31 |
def sidebar(self, tags, mode):
|
32 |
with st.sidebar:
|
33 |
+
# tag = st.selectbox('Select a tag', tags, key='tag')
|
34 |
# st.write('---')
|
35 |
with st.form('summary_sidebar_form'):
|
36 |
st.write('## Want a more comprehensive summary?')
|
|
|
48 |
# if submit_feedback:
|
49 |
# print(feedback)
|
50 |
|
51 |
+
# return tag
|
52 |
|
53 |
def leaderboard(self, tag, db_table):
|
54 |
+
tag = '%' if tag == 'overview' else tag
|
55 |
|
56 |
# get the ranking results of the current user
|
57 |
curser = RANKING_CONN.cursor()
|
|
|
65 |
# sort the modelVersion_standings by value into a list of tuples in descending order
|
66 |
st.session_state.modelVersion_standings[tag] = sorted(st.session_state.modelVersion_standings[tag].items(), key=lambda x: x[1], reverse=True)
|
67 |
|
68 |
+
# tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
|
69 |
|
70 |
+
# with tab1:
|
71 |
# self.podium(modelVersion_standings)
|
72 |
+
# switch_stage = st.toggle('Manual Reorder', key='switch_stage')
|
73 |
+
|
74 |
+
example_prompts = []
|
75 |
+
# get example images
|
76 |
+
for key, value in st.session_state.selected_dict.items():
|
77 |
+
for model in st.session_state.modelVersion_standings[tag]:
|
78 |
+
if model[0] in value:
|
79 |
+
example_prompts.append(key)
|
80 |
+
|
81 |
+
# if switch_stage:
|
82 |
+
# self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit', example_prompts=example_prompts)
|
83 |
+
# else:
|
84 |
+
self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='display', example_prompts=example_prompts)
|
85 |
+
# if st.session_state.summary_mode == 'display':
|
86 |
+
# switch_stage = st.button('Manual Reorder', key='switch_stage_edit', on_click=lambda: st.session_state.__setitem__('summary_mode', 'edit'))
|
87 |
+
# self.podium_expander(tag, n=3, summary_mode='display')
|
88 |
+
#
|
89 |
+
# elif st.session_state.summary_mode == 'edit':
|
90 |
+
# switch_stage = st.button('Done', key='switch_stage_done', type='primary', on_click=lambda: st.session_state.__setitem__('summary_mode', 'display'))
|
91 |
+
# self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit')
|
92 |
+
|
93 |
+
# with tab2:
|
94 |
+
st.write('---')
|
95 |
+
st.write('**Detailed information of all selected models**')
|
96 |
+
detailed_info = pd.merge(pd.DataFrame(st.session_state.modelVersion_standings[tag], columns=['modelVersion_id', 'ranking_score']), self.roster, on='modelVersion_id')
|
97 |
+
|
98 |
+
detailed_info = detailed_info[['model_name', 'modelVersion_name', 'model_download_count', 'tag', 'baseModel']]
|
99 |
+
|
100 |
+
st.data_editor(detailed_info, hide_index=False, disabled=True)
|
101 |
+
st.caption('You can click the header to sort the table by that column.')
|
102 |
|
103 |
def podium_expander(self, tag, example_prompts, n=3, summary_mode: ['display', 'edit'] = 'display'):
|
104 |
|
|
|
111 |
icon = 'π₯'if i == 0 else 'π₯' if i == 1 else 'π₯' if i == 2 else 'π'
|
112 |
podium_display = st.columns([1, 14], gap='medium')
|
113 |
with podium_display[0]:
|
114 |
+
st.title(f'{icon}')
|
115 |
+
# if summary_mode == 'display':
|
116 |
+
# st.title(f'{icon}')
|
117 |
+
# elif summary_mode == 'edit':
|
118 |
+
# settop = st.button('π', key=f'settop_{modelVersion_id}', help='Set this model to the top', disabled=i == 0, on_click=self.switch_order, args=(tag, i, 0), use_container_width=True)
|
119 |
+
# moveup = st.button('β¬', key=f'moveup_{modelVersion_id}', help='Move this model up', disabled=i == 0, on_click=self.switch_order, args=(tag, i, i - 1), use_container_width=True)
|
120 |
+
# movedown = st.button('β¬', key=f'movedown_{modelVersion_id}', help='Move this model down', disabled=i == n - 1, on_click=self.switch_order, args=(tag, i, i + 1), use_container_width=True)
|
121 |
with podium_display[1]:
|
122 |
+
title_display = st.columns([3.5, 2, 2, 2, 0.5, 0.5, 0.5])
|
123 |
with title_display[0]:
|
124 |
st.write(f'##### {model_name}, {modelVersion_name}')
|
125 |
# st.write(f'Ranking Score: {winning_times}')
|
126 |
with title_display[1]:
|
127 |
# image_display = st.selectbox('image display', ['Featured', 'All Images'], key=f'image_display_{modelVersion_id}', label_visibility='collapsed')
|
128 |
+
image_display = st.toggle('Show all images', key=f'image_display_{modelVersion_id}')
|
129 |
|
130 |
with title_display[2]:
|
131 |
+
st.link_button('Download', url, use_container_width=True)
|
132 |
with title_display[3]:
|
133 |
+
st.link_button('Civitai', f'https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}', use_container_width=True, type='primary')
|
134 |
# st.write(f'[Civitai Page](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}), [Model Download Link]({url}), Ranking Score: {winning_times}')
|
135 |
# with st.expander(f'**{icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})**, Ranking Score: {winning_times}'):
|
136 |
+
with title_display[4]:
|
137 |
+
settop = st.button('π', key=f'settop_{modelVersion_id}', help='Set this model to the top',
|
138 |
+
disabled=i == 0, on_click=self.switch_order, args=(tag, i, 0),
|
139 |
+
use_container_width=True)
|
140 |
+
with title_display[5]:
|
141 |
+
moveup = st.button('β¬', key=f'moveup_{modelVersion_id}', help='Move this model up',
|
142 |
+
disabled=i == 0, on_click=self.switch_order, args=(tag, i, i - 1),
|
143 |
+
use_container_width=True)
|
144 |
+
with title_display[6]:
|
145 |
+
movedown = st.button('β¬', key=f'movedown_{modelVersion_id}', help='Move this model down',
|
146 |
+
disabled=i == n - 1, on_click=self.switch_order, args=(tag, i, i + 1),
|
147 |
+
use_container_width=True)
|
148 |
|
149 |
if not image_display:
|
150 |
example_images = self.promptBook[self.promptBook['prompt_id'].isin(example_prompts) & (self.promptBook['modelVersion_id']==modelVersion_id)]['image_id'].values
|
|
|
155 |
)
|
156 |
|
157 |
else:
|
|
|
158 |
# with st.expander(f'Show Images'):
|
159 |
images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
|
160 |
|
161 |
+
safety_check = st.toggle('Include potentially unsafe or offensive images', value=False, key=modelVersion_id)
|
162 |
unsafe_prompts = json.load(open('data/unsafe_prompts.json', 'r'))
|
163 |
# merge dict values into one list
|
164 |
unsafe_prompts = [item for sublist in unsafe_prompts.values() for item in sublist]
|
|
|
173 |
images,
|
174 |
img_style={"margin": "5px", "height": "100px"}
|
175 |
)
|
176 |
+
st.write('π It may take a while to load all images. Please be patient, and **NEVER USE THE REFRESH BUTTON ON YOUR BROWSER**.')
|
177 |
|
178 |
# # st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
|
179 |
# col_num = 4
|
|
|
224 |
# get tags from database of the current user
|
225 |
db_table = 'sort_results' if mode == 'Drag and Sort' else 'battle_results'
|
226 |
|
227 |
+
tags = ['overview']
|
228 |
curser = RANKING_CONN.cursor()
|
229 |
curser.execute(
|
230 |
f"SELECT DISTINCT tag FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}'")
|
|
|
232 |
tags.append(row['tag'])
|
233 |
curser.close()
|
234 |
|
235 |
+
if tags == ['overview']:
|
236 |
st.info(f'No rankings are finished with {mode} mode yet.')
|
237 |
|
238 |
else:
|
239 |
+
tags = tags[0:1] if len(tags) == 2 else tags
|
240 |
+
tag = st.radio('Select a tag', tags, index=0, horizontal=True, label_visibility='collapsed')
|
241 |
+
self.sidebar(tags, mode)
|
242 |
self.leaderboard(tag, db_table)
|
243 |
|
244 |
with st.sidebar:
|
|
|
254 |
RANKING_CONN.commit()
|
255 |
curser.close()
|
256 |
|
257 |
+
st.toast('π **Thanks for your feedback! We will take it into consideration in our future work.**')
|
258 |
|
259 |
|
260 |
if __name__ == "__main__":
|
|
|
294 |
app = DashboardApp(roster, promptBook, session_finished)
|
295 |
app.app()
|
296 |
|
297 |
+
with open('./css/style.css') as f:
|
298 |
+
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
|
299 |
+
|
300 |
|