CarlosMalaga commited on
Commit
c047b4f
1 Parent(s): bb614f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -4
app.py CHANGED
@@ -173,12 +173,56 @@ def load_model():
173
  )
174
  retriever.index()
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  reader = RelikReaderForSpanExtraction("/home/user/app/models/small-extended-large-batch",
177
  dataset_kwargs={"use_nme": True})
178
 
179
- relik = Relik(reader=reader, retriever=retriever, window_size=32, window_stride=16, top_k=100, task="span", device="cpu", document_index_device="cpu")
 
 
 
 
 
180
 
181
- return relik
182
 
183
  def set_intro(css):
184
  # intro
@@ -209,6 +253,13 @@ def run_client():
209
  set_sidebar(css)
210
  set_intro(css)
211
 
 
 
 
 
 
 
 
212
  # text input
213
  text = st.text_area(
214
  "Enter Text Below:",
@@ -232,10 +283,30 @@ def run_client():
232
 
233
  if "relik_model" not in st.session_state.keys():
234
  st.session_state["relik_model"] = load_model()
235
- relik_model = st.session_state["relik_model"]
236
 
237
  # ReLik API call
238
  if submit:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  text = text.strip()
240
  if text:
241
  st.markdown("####")
@@ -269,7 +340,7 @@ def run_client():
269
  text = """
270
  <h2 style='color: black;'>Possible Candidates:</h2>
271
  <ul style='color: black;'>
272
- """ + "".join(f"<li>{candidate}</li>" for candidate in dict_of_ents_candidates["ents"][2:12]) + "</ul>"
273
 
274
  st.markdown(text, unsafe_allow_html=True)
275
 
 
173
  )
174
  retriever.index()
175
 
176
+
177
+
178
+ retriever_question = GoldenRetriever(
179
+ question_encoder="/home/user/app/models/retriever/question_encoder",
180
+ document_index="home/user/app/models/retriever/document_index/questions"
181
+
182
+ )
183
+
184
+ retriever_intervention = GoldenRetriever(
185
+ question_encoder="/home/user/app/models/retriever/question_encoder",
186
+ document_index="home/user/app/models/retriever/document_index/interventions"
187
+
188
+ )
189
+
190
+ retriever_outcome = GoldenRetriever(
191
+ question_encoder="/home/user/app/models/retriever/question_encoder",
192
+ document_index="home/user/app/models/retriever/document_index/outcomes"
193
+
194
+ )
195
+
196
+ retriever_question_db = GoldenRetriever(
197
+ question_encoder="/home/user/app/models/retriever/question_encoder",
198
+ document_index="home/user/app/models/retriever/document_index/question_db"
199
+
200
+ )
201
+
202
+ retriever_intervention_db = GoldenRetriever(
203
+ question_encoder="/home/user/app/models/retriever/question_encoder",
204
+ document_index="home/user/app/models/retriever/document_index/interventions_db"
205
+
206
+ )
207
+
208
+ retriever_outcome_db = GoldenRetriever(
209
+ question_encoder="/home/user/app/models/retriever/question_encoder",
210
+ document_index="home/user/app/models/retriever/document_index/outcomes_db"
211
+
212
+ )
213
+
214
+
215
  reader = RelikReaderForSpanExtraction("/home/user/app/models/small-extended-large-batch",
216
  dataset_kwargs={"use_nme": True})
217
 
218
+ relik_question = Relik(reader=reader, retriever=retriever_question, window_size="none", top_k=100, task="span", device="cuda", document_index_device="cpu")
219
+ relik_intervention = Relik(reader=reader, retriever=retriever_intervention, window_size="none", top_k=100, task="span", device="cuda", document_index_device="cpu")
220
+ relik_outcome = Relik(reader=reader, retriever=retriever_outcome, window_size="none", top_k=100, task="span", device="cuda", document_index_device="cpu")
221
+ relik_question_db = Relik(reader=reader, retriever=retriever_question_db, window_size="none", top_k=100, task="span", device="cuda", document_index_device="cpu")
222
+ relik_intrervention_db = Relik(reader=reader, retriever=retriever_intervention_db, window_size="none", top_k=100, task="span", device="cuda", document_index_device="cpu")
223
+ relik_outcome_db = Relik(reader=reader, retriever=retriever_outcome_db, window_size="none", top_k=100, task="span", device="cuda", document_index_device="cpu")
224
 
225
+ return [relik_question, relik_intervention, relik_outcome, relik_question_db, relik_intrervention_db, relik_outcome_db]
226
 
227
  def set_intro(css):
228
  # intro
 
253
  set_sidebar(css)
254
  set_intro(css)
255
 
256
+ # Radio button selection
257
+ analysis_type = st.radio(
258
+ "Choose analysis type:",
259
+ options=["intervention", "outcome", "question", "db intervention", "db outcome", "db question"],
260
+ index=2 # Default to 'question'
261
+ )
262
+
263
  # text input
264
  text = st.text_area(
265
  "Enter Text Below:",
 
283
 
284
  if "relik_model" not in st.session_state.keys():
285
  st.session_state["relik_model"] = load_model()
286
+ relik_model = st.session_state["relik_model"][0]
287
 
288
  # ReLik API call
289
  if submit:
290
+ if analysis_type == "question":
291
+ relik_model = st.session_state["relik_model"][0]
292
+
293
+ elif analysis_type == "intervention":
294
+ relik_model = st.session_state["relik_model"][1]
295
+ elif analysis_type == "outcome":
296
+ relik_model = st.session_state["relik_model"][2]
297
+ elif analysis_type == "db question":
298
+ relik_model = st.session_state["relik_model"][3]
299
+
300
+ elif analysis_type == "db intervention":
301
+ relik_model = st.session_state["relik_model"][4]
302
+
303
+ elif analysis_type == "db outcome":
304
+ print("hola")
305
+ relik_model = st.session_state["relik_model"][5]
306
+
307
+ else:
308
+ relik_model = st.session_state["relik_model"][0]
309
+
310
  text = text.strip()
311
  if text:
312
  st.markdown("####")
 
340
  text = """
341
  <h2 style='color: black;'>Possible Candidates:</h2>
342
  <ul style='color: black;'>
343
+ """ + "".join(f"<li style='color: black;'>{candidate}</li>" for candidate in dict_of_ents_candidates["ents"][2:12]) + "</ul>"
344
 
345
  st.markdown(text, unsafe_allow_html=True)
346