Spaces:
Running
Running
first version of ranking page!
Browse files- data/ranking_script.py +16 -0
- pages/Gallery.py +4 -2
- pages/Ranking.py +119 -40
data/ranking_script.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import Dataset
|
2 |
+
|
3 |
+
|
4 |
+
def init_ranking_data():
|
5 |
+
ds = Dataset.from_dict({'image_id': [], 'modelVersion_id': [], 'ranking': [], "user_name": [], "timestamp": []})\
|
6 |
+
|
7 |
+
# add example data
|
8 |
+
# note that image_id is a string, other ids are int
|
9 |
+
ds = ds.add_item({'image_id': '0', 'modelVersion_id': 0, 'ranking': 0, "user_name": "example_data", "timestamp": 0.0})
|
10 |
+
|
11 |
+
ds.push_to_hub("MAPS-research/GEMRec-Ranking", split='train')
|
12 |
+
|
13 |
+
|
14 |
+
if __name__ == '__main__':
|
15 |
+
init_ranking_data()
|
16 |
+
|
pages/Gallery.py
CHANGED
@@ -278,6 +278,8 @@ class GalleryApp:
|
|
278 |
switch_page("ranking")
|
279 |
|
280 |
def submit_actions(self, status, prompt_id):
|
|
|
|
|
281 |
if status == 'Select':
|
282 |
modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
|
283 |
st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
|
@@ -400,8 +402,8 @@ def load_hf_dataset():
|
|
400 |
login(token=os.environ.get("HF_TOKEN"))
|
401 |
|
402 |
# load from huggingface
|
403 |
-
roster = pd.DataFrame(load_dataset('
|
404 |
-
promptBook = pd.DataFrame(load_dataset('
|
405 |
# images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
|
406 |
images_ds = None # set to None for now since we use s3 bucket to store images
|
407 |
|
|
|
278 |
switch_page("ranking")
|
279 |
|
280 |
def submit_actions(self, status, prompt_id):
|
281 |
+
# remove counter from session state
|
282 |
+
st.session_state.pop('counter', None)
|
283 |
if status == 'Select':
|
284 |
modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
|
285 |
st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
|
|
|
402 |
login(token=os.environ.get("HF_TOKEN"))
|
403 |
|
404 |
# load from huggingface
|
405 |
+
roster = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Roster', split='train'))
|
406 |
+
promptBook = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Metadata', split='train'))
|
407 |
# images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
|
408 |
images_ds = None # set to None for now since we use s3 bucket to store images
|
409 |
|
pages/Ranking.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
3 |
import streamlit as st
|
@@ -17,7 +18,7 @@ class RankingApp:
|
|
17 |
# self.batch_num += 1 if len(self.promptBook) % self.batch_size != 0 else 0
|
18 |
|
19 |
if 'counter' not in st.session_state:
|
20 |
-
st.session_state.counter =
|
21 |
|
22 |
def sidebar(self):
|
23 |
with st.sidebar:
|
@@ -37,18 +38,27 @@ class RankingApp:
|
|
37 |
# input image metadata
|
38 |
prompt = st.text_area('Prompt', selected_prompt, height=150, key='prompt', disabled=True)
|
39 |
negative_prompt = st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150, key='negative_prompt', disabled=True)
|
40 |
-
st.form_submit_button('Generate Images', type='primary', use_container_width=True)
|
41 |
|
42 |
return prompt_tags, tag, prompt_id, items
|
43 |
|
44 |
-
def draggable_images(self, items, layout='portrait'):
|
45 |
# init ranking by the order of items
|
|
|
46 |
if 'ranking' not in st.session_state:
|
47 |
st.session_state.ranking = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
for i in range(len(items)):
|
49 |
-
st.session_state.ranking[str(items['image_id'][i])] = i
|
|
|
|
|
|
|
50 |
|
51 |
-
print(items)
|
52 |
with elements('dashboard'):
|
53 |
if layout == 'portrait':
|
54 |
col_num = 4
|
@@ -57,14 +67,17 @@ class RankingApp:
|
|
57 |
elif layout == 'landscape':
|
58 |
col_num = 2
|
59 |
layout = [
|
60 |
-
dashboard.Item(str(items['image_id'][i]), i % col_num * 2, i // col_num, 2, 1.
|
61 |
i in range(len(items))
|
62 |
]
|
63 |
|
64 |
with dashboard.Grid(layout, cols={'lg': 4, 'md': 4, 'sm': 4, 'xs': 4, 'xxs': 2}, onLayoutChange=self.handle_layout_change, margin=[18, 18], containerPadding=[0, 0]):
|
65 |
for i in range(len(layout)):
|
66 |
with mui.Card(key=str(items['image_id'][i]), variant="outlined"):
|
67 |
-
|
|
|
|
|
|
|
68 |
|
69 |
mui.Chip(label=rank,
|
70 |
# variant="outlined" if rank!=1 else "default",
|
@@ -79,7 +92,7 @@ class RankingApp:
|
|
79 |
# image={"data:image/png;base64", img_str},
|
80 |
image=img_url,
|
81 |
alt="There should be an image",
|
82 |
-
sx={"height": "100%", "object-fit": "
|
83 |
)
|
84 |
|
85 |
def handle_layout_change(self, updated_layout):
|
@@ -87,26 +100,95 @@ class RankingApp:
|
|
87 |
sorted_list = sorted(updated_layout, key=lambda x: (x['y'], x['x']))
|
88 |
sorted_list = [str(item['i']) for item in sorted_list]
|
89 |
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
92 |
|
93 |
def app(self):
|
94 |
st.title('Personal Image Ranking')
|
95 |
st.write('Here you can test out your selected images with any prompt you like.')
|
96 |
# st.write(self.promptBook)
|
97 |
|
|
|
|
|
|
|
|
|
|
|
98 |
prompt_tags, tag, prompt_id, items = self.sidebar()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
-
sorting, control = st.columns((11, 1), gap='large')
|
101 |
-
with sorting:
|
102 |
-
# st.write('## Sorting')
|
103 |
-
# st.write('Please drag the images to sort them.')
|
104 |
-
st.progress((st.session_state.counter + 1) / self.batch_num, text=f"Batch {st.session_state.counter + 1} / {self.batch_num}")
|
105 |
-
self.draggable_images(items.iloc[self.batch_size*st.session_state.counter: self.batch_size*(st.session_state.counter+1)], layout='portrait')
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
|
112 |
if __name__ == "__main__":
|
@@ -119,13 +201,15 @@ if __name__ == "__main__":
|
|
119 |
switch_page("home")
|
120 |
|
121 |
else:
|
122 |
-
|
|
|
123 |
for key, value in st.session_state.selected_dict.items():
|
124 |
for v in value:
|
125 |
-
if v
|
126 |
-
|
|
|
127 |
|
128 |
-
if
|
129 |
st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
|
130 |
gallery_btn = st.button('Go to Gallery')
|
131 |
if gallery_btn:
|
@@ -134,21 +218,16 @@ if __name__ == "__main__":
|
|
134 |
# st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
|
135 |
roster, promptBook, images_ds = load_hf_dataset()
|
136 |
print(st.session_state.selected_dict)
|
137 |
-
st.write("# Full function is coming soon.")
|
138 |
-
|
139 |
-
st.
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
# promptBook_selected = promptBook_selected.reset_index(drop=True)
|
150 |
-
# images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
|
151 |
-
#
|
152 |
-
# app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
|
153 |
-
# app.app()
|
154 |
|
|
|
1 |
+
import datasets
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
4 |
import streamlit as st
|
|
|
18 |
# self.batch_num += 1 if len(self.promptBook) % self.batch_size != 0 else 0
|
19 |
|
20 |
if 'counter' not in st.session_state:
|
21 |
+
st.session_state.counter = {}
|
22 |
|
23 |
def sidebar(self):
|
24 |
with st.sidebar:
|
|
|
38 |
# input image metadata
|
39 |
prompt = st.text_area('Prompt', selected_prompt, height=150, key='prompt', disabled=True)
|
40 |
negative_prompt = st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150, key='negative_prompt', disabled=True)
|
41 |
+
st.form_submit_button('Generate Images [Coming Soon]', type='primary', use_container_width=True, disabled=True)
|
42 |
|
43 |
return prompt_tags, tag, prompt_id, items
|
44 |
|
45 |
+
def draggable_images(self, items, prompt_id, layout='portrait'):
|
46 |
# init ranking by the order of items
|
47 |
+
|
48 |
if 'ranking' not in st.session_state:
|
49 |
st.session_state.ranking = {}
|
50 |
+
|
51 |
+
if prompt_id not in st.session_state.ranking:
|
52 |
+
st.session_state.ranking[prompt_id] = {}
|
53 |
+
|
54 |
+
if st.session_state.counter[prompt_id] not in st.session_state.ranking[prompt_id]:
|
55 |
+
st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]] = {}
|
56 |
for i in range(len(items)):
|
57 |
+
st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][str(items['image_id'][i])] = i
|
58 |
+
else:
|
59 |
+
# set the index of items to the corresponding ranking value of the image_id
|
60 |
+
items.index = items['image_id'].apply(lambda x: st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][str(x)])
|
61 |
|
|
|
62 |
with elements('dashboard'):
|
63 |
if layout == 'portrait':
|
64 |
col_num = 4
|
|
|
67 |
elif layout == 'landscape':
|
68 |
col_num = 2
|
69 |
layout = [
|
70 |
+
dashboard.Item(str(items['image_id'][i]), i % col_num * 2, i // col_num, 2, 1.6, isResizable=False) for
|
71 |
i in range(len(items))
|
72 |
]
|
73 |
|
74 |
with dashboard.Grid(layout, cols={'lg': 4, 'md': 4, 'sm': 4, 'xs': 4, 'xxs': 2}, onLayoutChange=self.handle_layout_change, margin=[18, 18], containerPadding=[0, 0]):
|
75 |
for i in range(len(layout)):
|
76 |
with mui.Card(key=str(items['image_id'][i]), variant="outlined"):
|
77 |
+
prompt_id = st.session_state.prompt_id_tmp
|
78 |
+
batch_idx = st.session_state.counter[prompt_id]
|
79 |
+
|
80 |
+
rank = st.session_state.ranking[prompt_id][batch_idx][str(items['image_id'][i])] + 1
|
81 |
|
82 |
mui.Chip(label=rank,
|
83 |
# variant="outlined" if rank!=1 else "default",
|
|
|
92 |
# image={"data:image/png;base64", img_str},
|
93 |
image=img_url,
|
94 |
alt="There should be an image",
|
95 |
+
sx={"height": "100%", "object-fit": "contain", 'bgcolor': 'black'},
|
96 |
)
|
97 |
|
98 |
def handle_layout_change(self, updated_layout):
|
|
|
100 |
sorted_list = sorted(updated_layout, key=lambda x: (x['y'], x['x']))
|
101 |
sorted_list = [str(item['i']) for item in sorted_list]
|
102 |
|
103 |
+
prompt_id = st.session_state.prompt_id_tmp
|
104 |
+
batch_idx = st.session_state.counter[prompt_id]
|
105 |
+
|
106 |
+
for k in st.session_state.ranking[prompt_id][batch_idx].keys():
|
107 |
+
st.session_state.ranking[prompt_id][batch_idx][k] = sorted_list.index(k)
|
108 |
|
109 |
def app(self):
|
110 |
st.title('Personal Image Ranking')
|
111 |
st.write('Here you can test out your selected images with any prompt you like.')
|
112 |
# st.write(self.promptBook)
|
113 |
|
114 |
+
# save the current progress to session state
|
115 |
+
if 'progress' not in st.session_state:
|
116 |
+
st.session_state.progress = {}
|
117 |
+
# print('current progress: ', st.session_state.progress)
|
118 |
+
|
119 |
prompt_tags, tag, prompt_id, items = self.sidebar()
|
120 |
+
batch_num = len(items) // self.batch_size
|
121 |
+
batch_num += 1 if len(items) % self.batch_size != 0 else 0
|
122 |
+
|
123 |
+
st.session_state.counter[prompt_id] = 0 if prompt_id not in st.session_state.counter else st.session_state.counter[prompt_id]
|
124 |
+
|
125 |
+
# save prompt_id in session state
|
126 |
+
st.session_state.prompt_id_tmp = prompt_id
|
127 |
+
|
128 |
+
if prompt_id not in st.session_state.progress:
|
129 |
+
st.session_state.progress[prompt_id] = 'ranking'
|
130 |
+
|
131 |
+
if st.session_state.progress[prompt_id] == 'ranking':
|
132 |
+
sorting, control = st.columns((11, 1), gap='large')
|
133 |
+
with sorting:
|
134 |
+
# st.write('## Sorting')
|
135 |
+
# st.write('Please drag the images to sort them.')
|
136 |
+
st.progress((st.session_state.counter[prompt_id] + 1) / batch_num, text=f"Batch {st.session_state.counter[prompt_id] + 1} / {batch_num}")
|
137 |
+
# st.write(items.iloc[self.batch_size*st.session_state.counter[prompt_id]: self.batch_size*(st.session_state.counter[prompt_id]+1)])
|
138 |
+
|
139 |
+
width, height = items.loc[0, 'size'].split('x')
|
140 |
+
if int(height) >= int(width):
|
141 |
+
self.draggable_images(items.iloc[self.batch_size*st.session_state.counter[prompt_id]: self.batch_size*(st.session_state.counter[prompt_id]+1)].reset_index(drop=True), prompt_id=prompt_id, layout='portrait')
|
142 |
+
else:
|
143 |
+
self.draggable_images(items.iloc[self.batch_size*st.session_state.counter[prompt_id]: self.batch_size*(st.session_state.counter[prompt_id]+1)].reset_index(drop=True), prompt_id=prompt_id, layout='landscape')
|
144 |
+
# st.write(str(st.session_state.ranking))
|
145 |
+
with control:
|
146 |
+
if st.session_state.counter[prompt_id] < batch_num - 1:
|
147 |
+
st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch', kwargs={'prompt_id': prompt_id})
|
148 |
+
else:
|
149 |
+
st.button(":ballot_box_with_check:", key='finished', on_click=self.next_batch, help='Finished', kwargs={'prompt_id': prompt_id, 'progress': 'finished'})
|
150 |
+
|
151 |
+
if st.session_state.counter[prompt_id] > 0:
|
152 |
+
st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch', kwargs={'prompt_id': prompt_id})
|
153 |
+
|
154 |
+
elif st.session_state.progress[prompt_id] == 'finished':
|
155 |
+
st.write('## You have ranked all models for this tag!')
|
156 |
+
st.write('Thank you for your participation! Feel free to do the following things:')
|
157 |
+
st.write('* Rank for other tags and prompts.')
|
158 |
+
st.write('* Back to the gallery page to see more images.')
|
159 |
+
st.write('* Rank again for this tag and prompt.')
|
160 |
+
st.write('*More functions are coming soon... Please stay tuned*')
|
161 |
+
|
162 |
+
gallery_btn = st.button('🖼️ Back to Gallery')
|
163 |
+
if gallery_btn:
|
164 |
+
switch_page('gallery')
|
165 |
+
|
166 |
+
restart_btn = st.button('🎖️ Rank Again')
|
167 |
+
if restart_btn:
|
168 |
+
st.session_state.progress['prompt_id'] = 'ranking'
|
169 |
+
st.session_state.counter[prompt_id] = 0
|
170 |
+
st.experimental_rerun()
|
171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
+
def next_batch(self, prompt_id, progress=None):
|
174 |
+
|
175 |
+
# save ranking to dataset
|
176 |
+
# print(st.session_state.ranking)
|
177 |
+
ranking_dataset = datasets.load_dataset('MAPS-research/GEMRec-Ranking', split='train')
|
178 |
+
for image_id in st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]].keys():
|
179 |
+
modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
|
180 |
+
ranking = st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][image_id]
|
181 |
+
# print({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
|
182 |
+
ranking_dataset = ranking_dataset.add_item({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
|
183 |
+
ranking_dataset.push_to_hub('MAPS-research/GEMRec-Ranking', split='train')
|
184 |
+
|
185 |
+
if progress == 'finished':
|
186 |
+
st.session_state.progress['prompt_id'] = 'finished'
|
187 |
+
else:
|
188 |
+
st.session_state.counter[prompt_id] += 1
|
189 |
+
|
190 |
+
def prev_batch(self, prompt_id):
|
191 |
+
st.session_state.counter[prompt_id] -= 1
|
192 |
|
193 |
|
194 |
if __name__ == "__main__":
|
|
|
201 |
switch_page("home")
|
202 |
|
203 |
else:
|
204 |
+
has_selection = False
|
205 |
+
|
206 |
for key, value in st.session_state.selected_dict.items():
|
207 |
for v in value:
|
208 |
+
if v:
|
209 |
+
has_selection = True
|
210 |
+
break
|
211 |
|
212 |
+
if not has_selection:
|
213 |
st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
|
214 |
gallery_btn = st.button('Go to Gallery')
|
215 |
if gallery_btn:
|
|
|
218 |
# st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
|
219 |
roster, promptBook, images_ds = load_hf_dataset()
|
220 |
print(st.session_state.selected_dict)
|
221 |
+
# st.write("# Full function is coming soon.")
|
222 |
+
|
223 |
+
# only select the part of the promptbook where tag is the same as st.session_state.selected_dict.keys(), while model version ids are the same as corresponding values to each key
|
224 |
+
promptBook_selected = pd.DataFrame()
|
225 |
+
for key, value in st.session_state.selected_dict.items():
|
226 |
+
promptBook_selected = promptBook_selected.append(promptBook[(promptBook['prompt_id'] == key) & (promptBook['modelVersion_id'].isin(value))])
|
227 |
+
promptBook_selected = promptBook_selected.reset_index(drop=True)
|
228 |
+
# st.write(promptBook_selected)
|
229 |
+
images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
|
230 |
+
|
231 |
+
app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
|
232 |
+
app.app()
|
|
|
|
|
|
|
|
|
|
|
233 |
|