riccorl commited on
Commit
1103011
1 Parent(s): 8197b11

Upload app

Browse files
Files changed (1) hide show
  1. app.py +33 -26
app.py CHANGED
@@ -123,13 +123,26 @@ def set_sidebar(css):
123
 
124
  def get_el_annotations(response):
125
  # swap labels key with ents
126
- dict_of_ents = {"text": response.text, "ents": []}
127
- dict_of_ents["ents"] = response.labels
128
  label_in_text = set(l["label"] for l in dict_of_ents["ents"])
129
  options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
130
  return dict_of_ents, options
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def set_intro(css):
134
  # intro
135
  st.markdown("# ReLik")
@@ -161,7 +174,7 @@ def run_client():
161
  # text input
162
  text = st.text_area(
163
  "Enter Text Below:",
164
- value="Obama went to Rome for a quick vacation.",
165
  height=200,
166
  max_chars=500,
167
  )
@@ -179,15 +192,9 @@ def run_client():
179
  submit = st.button("Annotate")
180
  # submit = st.button("Run")
181
 
182
- relik = Relik(
183
- question_encoder="/home/user/app/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
184
- document_index="/home/user/app/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered",
185
- reader="/home/user/app/models/relik-reader-aida-deberta-small",
186
- top_k=100,
187
- window_size=32,
188
- window_stride=16,
189
- candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
190
- )
191
 
192
  # ReLik API call
193
  if submit:
@@ -196,26 +203,26 @@ def run_client():
196
  st.markdown("####")
197
  st.markdown("#### Entity Linking")
198
  with st.spinner(text="In progress"):
199
- response = relik(text)
200
  # response = requests.post(RELIK, json=text)
201
  # if response.status_code != 200:
202
  # st.error("Error: {}".format(response.status_code))
203
  # else:
204
  # response = response.json()
205
 
206
- # Entity Linking
207
- # with stylable_container(
208
- # key="container_with_border",
209
- # css_styles="""
210
- # {
211
- # border: 1px solid rgba(49, 51, 63, 0.2);
212
- # border-radius: 0.5rem;
213
- # padding: 0.5rem;
214
- # padding-bottom: 2rem;
215
- # }
216
- # """,
217
- # ):
218
- # st.markdown("##")
219
  dict_of_ents, options = get_el_annotations(response=response)
220
  display = displacy.render(
221
  dict_of_ents, manual=True, style="ent", options=options
 
123
 
124
  def get_el_annotations(response):
125
  # swap labels key with ents
126
+ ents = [{"start": l.start, "end": l.end, "label": l.label} for l in response.labels]
127
+ dict_of_ents = {"text": response.text, "ents": ents}
128
  label_in_text = set(l["label"] for l in dict_of_ents["ents"])
129
  options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
130
  return dict_of_ents, options
131
 
132
 
133
+ @st.cache_resource()
134
+ def load_model():
135
+ return Relik(
136
+ question_encoder="/home/user/app/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
137
+ document_index="/home/user/app/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered",
138
+ reader="/home/user/app/models/relik-reader-aida-deberta-small",
139
+ top_k=100,
140
+ window_size=32,
141
+ window_stride=16,
142
+ candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
143
+ )
144
+
145
+
146
  def set_intro(css):
147
  # intro
148
  st.markdown("# ReLik")
 
174
  # text input
175
  text = st.text_area(
176
  "Enter Text Below:",
177
+ value="Michael Jordan was one of the best players in the NBA.",
178
  height=200,
179
  max_chars=500,
180
  )
 
192
  submit = st.button("Annotate")
193
  # submit = st.button("Run")
194
 
195
+ if "relik_model" not in st.session_state.keys():
196
+ st.session_state["relik_model"] = load_model()
197
+ relik_model = st.session_state["relik_model"]
 
 
 
 
 
 
198
 
199
  # ReLik API call
200
  if submit:
 
203
  st.markdown("####")
204
  st.markdown("#### Entity Linking")
205
  with st.spinner(text="In progress"):
206
+ response = relik_model(text)
207
  # response = requests.post(RELIK, json=text)
208
  # if response.status_code != 200:
209
  # st.error("Error: {}".format(response.status_code))
210
  # else:
211
  # response = response.json()
212
 
213
+ # Entity Linking
214
+ # with stylable_container(
215
+ # key="container_with_border",
216
+ # css_styles="""
217
+ # {
218
+ # border: 1px solid rgba(49, 51, 63, 0.2);
219
+ # border-radius: 0.5rem;
220
+ # padding: 0.5rem;
221
+ # padding-bottom: 2rem;
222
+ # }
223
+ # """,
224
+ # ):
225
+ # st.markdown("##")
226
  dict_of_ents, options = get_el_annotations(response=response)
227
  display = displacy.render(
228
  dict_of_ents, manual=True, style="ent", options=options