CarlosMalaga commited on
Commit
0b3d36f
1 Parent(s): 2fee2c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -48
app.py CHANGED
@@ -148,6 +148,17 @@ def get_retriever_annotations(response):
148
  label_in_text = set(l for l in dict_of_ents["ents"])
149
  options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
150
  return dict_of_ents, options
 
 
 
 
 
 
 
 
 
 
 
151
  import json
152
  io_map = {}
153
  with open("/home/user/app/models/retriever/document_index/documents.jsonl", "r") as r:
@@ -164,48 +175,66 @@ def load_model():
164
 
165
  )
166
 
167
- retriever_intervention = GoldenRetriever(
168
- question_encoder="models/retriever/level-4-small-no-negative-interventions/question_encoder",
169
- document_index="models/retriever/level-4-small-no-negative-interventions/document_index"
170
 
171
  )
172
 
173
- retriever_outcome = GoldenRetriever(
174
- question_encoder="models/retriever/level-4-small-no-negative-outcomes/question_encoder",
175
- document_index="models/retriever/level-4-small-no-negative-outcomes/document_index"
176
 
177
  )
178
 
179
- retriever_question_db = GoldenRetriever(
180
- question_encoder="/home/user/app/models/retriever/level-4-small-no-negatives/question_encoder",
181
- document_index="/home/user/app/models/retriever/level-4-small-no-negatives/document_index"
 
 
 
 
 
 
 
182
 
183
  )
184
 
185
- retriever_intervention_db = GoldenRetriever(
186
- question_encoder="models/retriever/level-4-small-no-negative-interventions/question_encoder",
187
- document_index="models/retriever/level-4-small-no-negative-interventions/document_index_db"
 
 
188
 
189
  )
190
 
191
- retriever_outcome_db = GoldenRetriever(
192
- question_encoder="models/retriever/level-4-small-no-negative-outcomes/question_encoder",
193
- document_index="models/retriever/level-4-small-no-negative-outcomes/document_index_db"
194
 
195
  )
196
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  reader = RelikReaderForSpanExtraction("/home/user/app/models/small-extended-large-batch",
199
  dataset_kwargs={"use_nme": True})
200
 
201
  relik_question = Relik(reader=reader, retriever=retriever_question, window_size="none", top_k=100, task="span", device="cpu", document_index_device="cpu")
202
- relik_intervention = Relik(reader=reader, retriever=retriever_intervention, window_size="none", top_k=100, task="span", device="cpu", document_index_device="cpu")
203
- relik_outcome = Relik(reader=reader, retriever=retriever_outcome, window_size="none", top_k=100, task="span", device="cpu", document_index_device="cpu")
204
- relik_question_db = Relik(reader=reader, retriever=retriever_question_db, window_size="none", top_k=100, task="span", device="cpu", document_index_device="cpu")
205
- relik_intrervention_db = Relik(reader=reader, retriever=retriever_intervention_db, window_size="none", top_k=100, task="span", device="cpu", document_index_device="cpu")
206
- relik_outcome_db = Relik(reader=reader, retriever=retriever_outcome_db, window_size="none", top_k=100, task="span", device="cpu", document_index_device="cpu")
207
 
208
- return [relik_question, relik_intervention, relik_outcome, relik_question_db, relik_intrervention_db, relik_outcome_db]
 
 
209
 
210
  def set_intro(css):
211
  # intro
@@ -239,10 +268,19 @@ def run_client():
239
  # Radio button selection
240
  analysis_type = st.radio(
241
  "Choose analysis type:",
242
- options=["question", "intervention", "outcome", "db intervention", "db outcome"],
243
  index=0 # Default to 'question'
244
  )
 
 
245
 
 
 
 
 
 
 
 
246
  # text input
247
  text = st.text_area(
248
  "Enter Text Below:",
@@ -273,40 +311,24 @@ def run_client():
273
  entity_linking_bool = False
274
 
275
 
276
- if analysis_type == "question":
277
  relik_model = st.session_state["relik_model"][0]
278
  entity_linking_bool = True
279
- elif analysis_type == "intervention":
280
- relik_model = st.session_state["relik_model"][1]
281
- elif analysis_type == "outcome":
282
- relik_model = st.session_state["relik_model"][2]
283
-
284
- elif analysis_type == "db intervention":
285
- relik_model = st.session_state["relik_model"][4]
286
-
287
- elif analysis_type == "db outcome":
288
- relik_model = st.session_state["relik_model"][5]
289
-
290
  else:
291
- relik_model = st.session_state["relik_model"][3]
 
 
292
 
293
  text = text.strip()
294
  if text:
295
  st.markdown("####")
296
  with st.spinner(text="In progress"):
297
- response = relik_model(text)
 
 
 
 
298
 
299
- # response = requests.post(RELIK, json=text)
300
- # if response.status_code != 200:
301
- # st.error("Error: {}".format(response.status_code))
302
- # else:
303
- # response = response.json()
304
-
305
- # st.markdown("##")
306
- dict_of_ents, options = get_el_annotations(response=response)
307
- dict_of_ents_candidates, options_candidates = get_retriever_annotations(response=response)
308
-
309
- if entity_linking_bool:
310
  st.markdown("#### Entity Linking")
311
 
312
  display = displacy.render(
@@ -329,10 +351,18 @@ def run_client():
329
 
330
  st.markdown(text, unsafe_allow_html=True)
331
  else:
 
 
 
 
 
 
 
 
332
  text = """
333
  <h2 style='color: black;'>Possible Candidates:</h2>
334
  <ul style='color: black;'>
335
- """ + "".join(f"<li style='color: black;'>{candidate}</li>" for candidate in dict_of_ents_candidates["ents"][2:12]) + "</ul>"
336
 
337
  st.markdown(text, unsafe_allow_html=True)
338
  else:
 
148
  label_in_text = set(l for l in dict_of_ents["ents"])
149
  options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
150
  return dict_of_ents, options
151
+
152
+
153
+ def get_retriever_annotations_candidates(text, ents):
154
+ el_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://en.wikipedia.org/wiki/{}' style='color: #414141'><i class='fa-brands fa-wikipedia-w fa-xs'></i> <span style='font-size: 1.0em; font-family: monospace'> {}</span></a>"
155
+ # swap labels key with ents
156
+ dict_of_ents = {"text": text, "ents": ents}
157
+ label_in_text = set(l for l in dict_of_ents["ents"])
158
+ options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
159
+ return dict_of_ents, options
160
+
161
+
162
  import json
163
  io_map = {}
164
  with open("/home/user/app/models/retriever/document_index/documents.jsonl", "r") as r:
 
175
 
176
  )
177
 
178
+ retriever_intervention_gpt_taxonomy = GoldenRetriever(
179
+ question_encoder="models/retriever/intervention/gpt/taxonomy/question_encoder",
180
+ document_index="models/retriever/intervention/gpt/taxonomy/document_index"
181
 
182
  )
183
 
184
+ retriever_intervention_gpt_llama_taxonomy = GoldenRetriever(
185
+ question_encoder="models/retriever/intervention/gpt+llama/taxonomy/question_encoder",
186
+ document_index="models/retriever/intervention/gpt+llama/taxonomy/document_index"
187
 
188
  )
189
 
190
+
191
+ retriever_intervention_gpt_db = GoldenRetriever(
192
+ question_encoder="models/retriever/intervention/gpt/db/question_encoder",
193
+ document_index="models/retriever/intervention/gpt/db/document_index"
194
+
195
+ )
196
+
197
+ retriever_intervention_gpt_llama_db = GoldenRetriever(
198
+ question_encoder="models/retriever/intervention/gpt+llama/db/question_encoder",
199
+ document_index="models/retriever/intervention/gpt+llama/db/document_index"
200
 
201
  )
202
 
203
+
204
+
205
+ retriever_outcome_gpt_taxonomy = GoldenRetriever(
206
+ question_encoder="models/retriever/outcome/gpt/taxonomy/question_encoder",
207
+ document_index="models/retriever/outcome/gpt/taxonomy/document_index"
208
 
209
  )
210
 
211
+ retriever_outcome_gpt_llama_taxonomy = GoldenRetriever(
212
+ question_encoder="models/retriever/outcome/gpt+llama/taxonomy/question_encoder",
213
+ document_index="models/retriever/outcome/gpt+llama/taxonomy/document_index"
214
 
215
  )
216
 
217
+
218
+ retriever_outcome_gpt_db = GoldenRetriever(
219
+ question_encoder="models/retriever/outcome/gpt/db/question_encoder",
220
+ document_index="models/retriever/outcome/gpt/db/document_index"
221
+
222
+ )
223
+
224
+ retriever_outcome_gpt_llama_db = GoldenRetriever(
225
+ question_encoder="models/retriever/outcome/gpt+llama/db/question_encoder",
226
+ document_index="models/retriever/outcome/gpt+llama/db/document_index"
227
+
228
+ )
229
 
230
  reader = RelikReaderForSpanExtraction("/home/user/app/models/small-extended-large-batch",
231
  dataset_kwargs={"use_nme": True})
232
 
233
  relik_question = Relik(reader=reader, retriever=retriever_question, window_size="none", top_k=100, task="span", device="cpu", document_index_device="cpu")
 
 
 
 
 
234
 
235
+ selection_options = ["DB Intervention (GPT)", "DB Outcome (GPT)", "DB Intervention (GPT+Llama)", "DB Outcome (GPT+Llama)", "Taxonomy Intervention (GPT)", "Taxonomy Outcome (GPT)", "Taxonomy Intervention (GPT+Llama)", "Taxonomy Outcome (GPT+Llama)"]
236
+
237
+ return [relik_question, retriever_intervention_gpt_db, retriever_outcome_gpt_db, retriever_intervention_gpt_llama_db, retriever_outcome_gpt_llama_db, retriever_intervention_gpt_taxonomy, retriever_outcome_gpt_taxonomy, retriever_intervention_gpt_llama_taxonomy, retriever_outcome_gpt_llama_taxonomy]
238
 
239
  def set_intro(css):
240
  # intro
 
268
  # Radio button selection
269
  analysis_type = st.radio(
270
  "Choose analysis type:",
271
+ options=["Retriever", "Entity Linking"],
272
  index=0 # Default to 'question'
273
  )
274
+
275
+ selection_options = ["DB Intervention (GPT)", "DB Outcome (GPT)", "DB Intervention (GPT+Llama)", "DB Outcome (GPT+Llama)", "Taxonomy Intervention (GPT)", "Taxonomy Outcome (GPT)", "Taxonomy Intervention (GPT+Llama)", "Taxonomy Outcome (GPT+Llama)"]
276
 
277
+ if analysis_type == "Retriever"
278
+ # Selection list using selectbox
279
+ selection_list = st.selectbox(
280
+ "Select an option:",
281
+ options=options
282
+ )
283
+
284
  # text input
285
  text = st.text_area(
286
  "Enter Text Below:",
 
311
  entity_linking_bool = False
312
 
313
 
314
+ if analysis_type == "Entity Linking":
315
  relik_model = st.session_state["relik_model"][0]
316
  entity_linking_bool = True
 
 
 
 
 
 
 
 
 
 
 
317
  else:
318
+ model_idx = selection_options.index(selection_list)
319
+ relik_model = st.session_state["relik_model"][model_idx+1]
320
+
321
 
322
  text = text.strip()
323
  if text:
324
  st.markdown("####")
325
  with st.spinner(text="In progress"):
326
+ if entity_linking_bool:
327
+ response = relik_model(text)
328
+
329
+ dict_of_ents, options = get_el_annotations(response=response)
330
+ dict_of_ents_candidates, options_candidates = get_retriever_annotations(response=response)
331
 
 
 
 
 
 
 
 
 
 
 
 
332
  st.markdown("#### Entity Linking")
333
 
334
  display = displacy.render(
 
351
 
352
  st.markdown(text, unsafe_allow_html=True)
353
  else:
354
+ response = relik_model.retrieve(text, k=10, batch_size=100, progress_bar=False)
355
+
356
+ candidates_text = []
357
+ for pred in response[0]:
358
+ candidates.append(pred.document.text)
359
+
360
+ dict_of_ents_candidates, options_candidates = get_retriever_annotations_candidates(text, candidates_text)
361
+
362
  text = """
363
  <h2 style='color: black;'>Possible Candidates:</h2>
364
  <ul style='color: black;'>
365
+ """ + "".join(f"<li style='color: black;'>{candidate}</li>" for candidate in dict_of_ents_candidates["ents"][0:10]) + "</ul>"
366
 
367
  st.markdown(text, unsafe_allow_html=True)
368
  else: