Ricercar commited on
Commit
c8f09d8
1 Parent(s): b81f75d

add connections to database

Browse files
Files changed (4) hide show
  1. Home.py +2 -1
  2. data/ranking_script.py +2 -1
  3. pages/Ranking.py +30 -5
  4. requirements.txt +2 -1
Home.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import streamlit as st
2
  import random
3
  import time
@@ -28,7 +29,7 @@ def save_user_id(user_id):
28
  print(user_id)
29
  if not user_id:
30
  user_id = 'anonymous' + str(random.randint(0, 100000))
31
- st.session_state.user_id = [user_id, time.time()]
32
 
33
 
34
  def logout():
 
1
+ from datetime import datetime
2
  import streamlit as st
3
  import random
4
  import time
 
29
  print(user_id)
30
  if not user_id:
31
  user_id = 'anonymous' + str(random.randint(0, 100000))
32
+ st.session_state.user_id = [user_id, datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
33
 
34
 
35
  def logout():
data/ranking_script.py CHANGED
@@ -1,4 +1,5 @@
1
  from datasets import Dataset
 
2
 
3
 
4
  def init_ranking_data():
@@ -6,7 +7,7 @@ def init_ranking_data():
6
 
7
  # add example data
8
  # note that image_id is a string, other ids are int
9
- ds = ds.add_item({'image_id': '0', 'modelVersion_id': 0, 'ranking': 0, "user_name": "example_data", "timestamp": 0.0})
10
 
11
  ds.push_to_hub("MAPS-research/GEMRec-Ranking", split='train')
12
 
 
1
  from datasets import Dataset
2
+ from datetime import datetime
3
 
4
 
5
  def init_ranking_data():
 
7
 
8
  # add example data
9
  # note that image_id is a string, other ids are int
10
+ ds = ds.add_item({'image_id': '0', 'modelVersion_id': 0, 'ranking': 0, "user_name": "example_data", "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
11
 
12
  ds.push_to_hub("MAPS-research/GEMRec-Ranking", split='train')
13
 
pages/Ranking.py CHANGED
@@ -1,6 +1,9 @@
 
 
1
  import datasets
2
  import numpy as np
3
  import pandas as pd
 
4
  import streamlit as st
5
 
6
  from streamlit_elements import elements, mui, html, dashboard, nivo
@@ -165,22 +168,28 @@ class RankingApp:
165
 
166
  restart_btn = st.button('🎖️ Rank Again')
167
  if restart_btn:
168
- st.session_state.progress['prompt_id'] = 'ranking'
169
  st.session_state.counter[prompt_id] = 0
170
  st.experimental_rerun()
171
 
172
 
173
  def next_batch(self, prompt_id, progress=None):
174
-
175
  # save ranking to dataset
176
  # print(st.session_state.ranking)
177
- ranking_dataset = datasets.load_dataset('MAPS-research/GEMRec-Ranking', split='train')
 
178
  for image_id in st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]].keys():
179
  modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
180
  ranking = st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][image_id]
181
  # print({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
182
- ranking_dataset = ranking_dataset.add_item({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
183
- ranking_dataset.push_to_hub('MAPS-research/GEMRec-Ranking', split='train')
 
 
 
 
 
 
184
 
185
  if progress == 'finished':
186
  st.session_state.progress[prompt_id] = 'finished'
@@ -191,6 +200,20 @@ class RankingApp:
191
  st.session_state.counter[prompt_id] -= 1
192
 
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  if __name__ == "__main__":
195
  st.set_page_config(page_title="Personal Image Ranking", page_icon="🎖️️", layout="wide")
196
 
@@ -218,7 +241,9 @@ if __name__ == "__main__":
218
  # st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
219
  roster, promptBook, images_ds = load_hf_dataset()
220
  print(st.session_state.selected_dict)
 
221
  # st.write("# Full function is coming soon.")
 
222
 
223
  # only select the part of the promptbook where tag is the same as st.session_state.selected_dict.keys(), while model version ids are the same as corresponding values to each key
224
  promptBook_selected = pd.DataFrame()
 
1
+ import os
2
+
3
  import datasets
4
  import numpy as np
5
  import pandas as pd
6
+ import pymysql.cursors
7
  import streamlit as st
8
 
9
  from streamlit_elements import elements, mui, html, dashboard, nivo
 
168
 
169
  restart_btn = st.button('🎖️ Rank Again')
170
  if restart_btn:
171
+ st.session_state.progress[prompt_id] = 'ranking'
172
  st.session_state.counter[prompt_id] = 0
173
  st.experimental_rerun()
174
 
175
 
176
  def next_batch(self, prompt_id, progress=None):
 
177
  # save ranking to dataset
178
  # print(st.session_state.ranking)
179
+ # ranking_dataset = datasets.load_dataset('MAPS-research/GEMRec-Ranking', split='train')
180
+ curser = RANKING_CONN.cursor()
181
  for image_id in st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]].keys():
182
  modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
183
  ranking = st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][image_id]
184
  # print({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
185
+ # ranking_dataset = ranking_dataset.add_item({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
186
+
187
+ query = "INSERT INTO rankings (image_id, modelVersion_id, ranking, user_name, timestamp) VALUES (%s, %s, %s, %s, %s)"
188
+ curser.execute(query, (image_id, modelVersion_id, ranking, st.session_state.user_id[0], st.session_state.user_id[1]))
189
+
190
+ curser.close()
191
+ RANKING_CONN.commit()
192
+ # ranking_dataset.push_to_hub('MAPS-research/GEMRec-Ranking', split='train')
193
 
194
  if progress == 'finished':
195
  st.session_state.progress[prompt_id] = 'finished'
 
200
  st.session_state.counter[prompt_id] -= 1
201
 
202
 
203
+ def connect_to_db():
204
+ conn = pymysql.connect(
205
+ host=os.environ.get('RANKING_DB_HOST'),
206
+ port=3306,
207
+ database='myRanking',
208
+ user=os.environ.get('RANKING_DB_USER'),
209
+ password=os.environ.get('RANKING_DB_PASSWORD'),
210
+ charset='utf8mb4',
211
+ cursorclass=pymysql.cursors.DictCursor
212
+ )
213
+
214
+ return conn
215
+
216
+
217
  if __name__ == "__main__":
218
  st.set_page_config(page_title="Personal Image Ranking", page_icon="🎖️️", layout="wide")
219
 
 
241
  # st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
242
  roster, promptBook, images_ds = load_hf_dataset()
243
  print(st.session_state.selected_dict)
244
+
245
  # st.write("# Full function is coming soon.")
246
+ RANKING_CONN = connect_to_db()
247
 
248
  # only select the part of the promptbook where tag is the same as st.session_state.selected_dict.keys(), while model version ids are the same as corresponding values to each key
249
  promptBook_selected = pd.DataFrame()
requirements.txt CHANGED
@@ -3,4 +3,5 @@ streamlit-elements==0.1.0
3
  streamlit-extras
4
  altair<5
5
  streamlit-vega-lite
6
- scikit-learn
 
 
3
  streamlit-extras
4
  altair<5
5
  streamlit-vega-lite
6
+ scikit-learn
7
+ pymysql