Ricercar commited on
Commit
367d6d8
·
1 Parent(s): 6febf41

add new gallery view, add selectable images

Browse files
Files changed (1) hide show
  1. app.py +85 -15
app.py CHANGED
@@ -18,7 +18,7 @@ class GalleryApp:
18
  self.promptBook = promptBook
19
  st.set_page_config(layout="wide")
20
 
21
- def gallery(self, items, col_num, info):
22
  cols = st.columns(col_num)
23
  # # sort items by brisque score
24
  # items = items.sort_values(by=['brisque'], ascending=True).reset_index(drop=True)
@@ -34,23 +34,49 @@ class GalleryApp:
34
  # st.image(image, use_column_width=True)
35
  # with tab2:
36
  # st.image(image, use_column_width=True)
 
 
 
 
 
 
37
  for key in info:
38
  st.write(f"**{key}**: {items.iloc[idx][key]}")
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def app(self):
41
  st.title('Model Coffer Gallery')
42
  st.write('This is a gallery of images generated by the models in the Model Coffer')
43
 
44
- # metadata, images = st.columns([1, 3])
45
- # with images:
46
- # prompt_tags = self.promptBook['tag'].unique()
47
- # # sort tags by alphabetical order
48
- # prompt_tags = np.sort(prompt_tags)
49
- #
50
- # selecters = st.columns(3)
51
- # with selecters[0]:
52
- # tag = st.selectbox('Select a tag', prompt_tags)
53
-
54
  with st.sidebar:
55
  prompt_tags = self.promptBook['tag'].unique()
56
  # sort tags by alphabetical order
@@ -96,11 +122,11 @@ class GalleryApp:
96
 
97
  res = requests.get(f'https://civitai.com/images', params={'post_id': prompt_id})
98
  st.write(res)
99
- # image_url = res.json()['items'][0]['url']
100
- # st.image(image_url, use_column_width=True)
101
 
102
  # with images:
103
- selecters = st.columns([1, 1, 2])
104
 
105
  with selecters[0]:
106
  # sort_by = st.selectbox('Sort by', items.columns[11: -1])
@@ -116,6 +142,7 @@ class GalleryApp:
116
  order = False
117
 
118
  items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
 
119
  with selecters[2]:
120
  info = st.multiselect('Show Info',
121
  ['model_download_count', 'model_name', 'model_id',
@@ -123,7 +150,46 @@ class GalleryApp:
123
  default=sort_by)
124
 
125
  col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
126
- self.gallery(items, col_num, info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
 
129
  if __name__ == '__main__':
@@ -140,6 +206,10 @@ if __name__ == '__main__':
140
  print('loading promptBook')
141
 
142
  st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
 
 
 
 
143
  st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
144
  print(st.session_state.images)
145
  print('images loaded')
 
18
  self.promptBook = promptBook
19
  st.set_page_config(layout="wide")
20
 
21
+ def gallery_masonry(self, items, col_num, info):
22
  cols = st.columns(col_num)
23
  # # sort items by brisque score
24
  # items = items.sort_values(by=['brisque'], ascending=True).reset_index(drop=True)
 
34
  # st.image(image, use_column_width=True)
35
  # with tab2:
36
  # st.image(image, use_column_width=True)
37
+
38
+ # show checkbox
39
+ self.promptBook.loc[items.iloc[idx]['row_idx'].item(), 'checked'] = st.checkbox(
40
+ 'Select', value=self.promptBook.loc[items.iloc[idx]['row_idx'].item(), 'checked'],
41
+ key=f'select_{idx}')
42
+
43
  for key in info:
44
  st.write(f"**{key}**: {items.iloc[idx][key]}")
45
 
46
+ def gallery_standard(self, items, col_num, info):
47
+ rows = len(items) // col_num + 1
48
+ containers = [st.container() for _ in range(rows*2)]
49
+ for idx in range(0, len(items), col_num):
50
+ # assign one container for each row
51
+ row_idx = (idx // col_num) * 2
52
+ with containers[row_idx]:
53
+ cols = st.columns(col_num)
54
+ for j in range(col_num):
55
+ if idx + j < len(items):
56
+ with cols[j]:
57
+ # show image
58
+ image = st.session_state.images[items.iloc[idx+j]['row_idx'].item()]['image']
59
+ st.image(image,
60
+ use_column_width=True,
61
+ )
62
+
63
+ # show checkbox
64
+ self.promptBook.loc[items.iloc[idx+j]['row_idx'].item(), 'checked'] = st.checkbox('Select', value=self.promptBook.loc[items.iloc[idx+j]['row_idx'].item(), 'checked'] , key=f'select_{idx+j}')
65
+
66
+ # show selected info
67
+ for key in info:
68
+ st.write(f"**{key}**: {items.iloc[idx+j][key]}")
69
+
70
+ # st.write(row_idx/2, idx+j, rows)
71
+ # extra_info = st.checkbox('Extra Info', key=f'extra_info_{idx+j}')
72
+ # if extra_info:
73
+ # with containers[row_idx+1]:
74
+ # st.image(image, use_column_width=True)
75
+
76
  def app(self):
77
  st.title('Model Coffer Gallery')
78
  st.write('This is a gallery of images generated by the models in the Model Coffer')
79
 
 
 
 
 
 
 
 
 
 
 
80
  with st.sidebar:
81
  prompt_tags = self.promptBook['tag'].unique()
82
  # sort tags by alphabetical order
 
122
 
123
  res = requests.get(f'https://civitai.com/images', params={'post_id': prompt_id})
124
  st.write(res)
125
+ image_url = res.json()['items'][0]['url']
126
+ st.image(image_url, use_column_width=True)
127
 
128
  # with images:
129
+ selecters = st.columns([1, 1, 2, 0.5])
130
 
131
  with selecters[0]:
132
  # sort_by = st.selectbox('Sort by', items.columns[11: -1])
 
142
  order = False
143
 
144
  items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
145
+
146
  with selecters[2]:
147
  info = st.multiselect('Show Info',
148
  ['model_download_count', 'model_name', 'model_id',
 
150
  default=sort_by)
151
 
152
  col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
153
+
154
+ with selecters[3]:
155
+ filter = st.selectbox('Filter', ['All', 'Checked', 'Unchecked'])
156
+ if filter == 'Checked':
157
+ items = items[items['checked'] == True].reset_index(drop=True)
158
+ elif filter == 'Unchecked':
159
+ items = items[items['checked'] == False].reset_index(drop=True)
160
+
161
+ with st.form(key='my_form'):
162
+ buttons = st.columns([1, 1, 1])
163
+ with buttons[0]:
164
+ submit = st.form_submit_button('Save selections', on_click=self.save_checked, use_container_width=True, type='primary')
165
+ with buttons[1]:
166
+ submit = st.form_submit_button('Reset current prompt', on_click=self.reset_current_prompt, kwargs={'prompt_id': prompt_id} , use_container_width=True)
167
+ with buttons[2]:
168
+ submit = st.form_submit_button('Reset all selections', on_click=self.reset_all, use_container_width=True)
169
+ self.gallery_standard(items, col_num, info)
170
+
171
+ def reset_current_prompt(self, prompt_id):
172
+ # reset current prompt
173
+ self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
174
+ self.save_checked()
175
+
176
+ def reset_all(self):
177
+ # reset all
178
+ self.promptBook.loc[:, 'checked'] = False
179
+ self.save_checked()
180
+
181
+ def save_checked(self):
182
+ # save checked images to huggingface dataset
183
+ dataset = load_dataset('NYUSHPRP/ModelCofferMetadata', split='train')
184
+ # get checked images
185
+ checked_info = self.promptBook['checked']
186
+
187
+ if 'checked' in dataset.column_names:
188
+ dataset = dataset.remove_columns('checked')
189
+ dataset = dataset.add_column('checked', checked_info)
190
+
191
+ # print('metadata dataset: ', dataset)
192
+ dataset.push_to_hub('NYUSHPRP/ModelCofferMetadata', split='train')
193
 
194
 
195
  if __name__ == '__main__':
 
206
  print('loading promptBook')
207
 
208
  st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
209
+ # add 'checked' column to promptBook if not exist
210
+ if 'checked' not in st.session_state.promptBook.columns:
211
+ st.session_state.promptBook['checked'] = False
212
+
213
  st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
214
  print(st.session_state.images)
215
  print('images loaded')