malteos commited on
Commit
7a888ed
β€’
1 Parent(s): 42610a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -101
app.py CHANGED
@@ -69,15 +69,10 @@ def st_load_dataset(name_or_path):
69
  if isinstance(dataset, DatasetDict):
70
  dataset = dataset['train']
71
 
72
- # load existing faiss
73
  for a in aspects:
74
  dataset.load_faiss_index(f'{a}_embeddings', f'{a}_embeddings.faiss')
75
 
76
- # add faiss
77
- #dataset.add_faiss_index(column=f'{aspect}_embeddings')
78
- #loaded_dataset.add_faiss_index(column='method_embeddings')
79
- #loaded_dataset.add_faiss_index(column='dataset_embeddings')
80
-
81
  return dataset
82
 
83
 
@@ -99,64 +94,58 @@ def get_paper(doc_id):
99
 
100
 
101
  def find_related_papers(paper_id, user_aspect):
102
- # Add result to session
103
-
104
- paper = get_paper(paper_id)
105
 
106
- if paper is None or 'title' not in paper or 'abstract' not in paper:
107
- raise ValueError('Could not retrieve data for input paper')
108
 
109
- title_abs = paper['title'] + ': ' + paper['abstract']
 
110
 
111
- # preprocess the input
112
- inputs = tokenizer(title_abs, padding=True, truncation=True, return_tensors="pt", max_length=512)
113
 
114
- # inference
115
- outputs = aspect_to_model[user_aspect](**inputs)
116
 
117
- # logger.info(f'attention_mask: {inputs["attention_mask"].shape}')
118
- #
119
- # logger.info(f'Outputs: {outputs["last_hidden_state"]}')
120
- # logger.info(f'Outputs: {outputs["last_hidden_state"].shape}')
121
 
122
- # Mean pool the token-level embeddings to get sentence-level embeddings
123
- embeddings = torch.sum(
124
- outputs["last_hidden_state"] * inputs['attention_mask'].unsqueeze(-1), dim=1
125
- ) / torch.clamp(torch.sum(inputs['attention_mask'], dim=1, keepdims=True), min=1e-9)
126
 
127
- result = dict(
128
- paper=paper,
129
- aspect=user_aspect,
130
- )
131
 
132
- result.update(dict(
133
- #embeddings=embeddings.tolist(),
134
- ))
 
135
 
136
- # Retrieval
137
- prompt = embeddings.detach().numpy()[0]
138
- scores, retrieved_examples = dataset.get_nearest_examples(f'{user_aspect}_embeddings', prompt, k=10)
139
 
140
- result.update(dict(
141
- related_papers=retrieved_examples,
142
- ))
143
 
144
- # st.session_state.results.append(result)
 
 
145
 
146
  return result
147
 
148
 
149
- # # Start session
150
- # if 'results' not in st.session_state:
151
- # st.session_state.results = []
152
-
153
  # Page
154
  st.title('Aspect-based Paper Similarity')
155
  st.markdown("""This demo showcases [Specialized Document Embeddings for Aspect-based Research Paper Similarity](#TODO).""")
156
 
157
  # Introduction
158
- st.markdown(f"""The model was trained using a triplet loss on machine learning papers from the [paperswithcode.com](https://paperswithcode.com/) corpus with the objective of pulling embeddings of papers with the same task, method, or datasetclose together. For a more comprehensive overview of the model check out the [model card on πŸ€— Model Hub]({model_hub_url}) or read [our paper](#TODO).
159
- """)
160
  st.markdown("""Enter a ArXiv ID or a DOI of a paper for that you want find similar papers.
161
 
162
  Try it yourself! πŸ‘‡""",
@@ -170,19 +159,20 @@ with st.form("aspect-input", clear_on_submit=False):
170
  placeholder='Any DOI, ACL, or ArXiv ID'
171
  )
172
 
 
 
 
 
 
 
 
 
173
  example = st.selectbox(
174
- label='Or select example',
175
- options=[
176
- "arXiv:2202.06671",
177
- '10.1016/j.eswa.2019.06.026'
178
- ]
179
  )
180
 
181
- # click_clear = st.button('clear text input', key=1)
182
- # if click_clear:
183
- # paper_id = st.text_input(
184
- # label='Enter paper ID (arXiv:<arxiv_id>, or <doi>):', value="XXX", placeholder='123')
185
-
186
  user_aspect = st.radio(
187
  label="In what aspect are you interested?",
188
  options=aspects
@@ -194,61 +184,29 @@ with st.form("aspect-input", clear_on_submit=False):
194
  # Listener
195
  if submitted:
196
  if paper_id or example:
197
- with st.spinner('Finding related papers...'):
198
- try:
199
- result = find_related_papers(paper_id if paper_id else example, user_aspect)
200
 
201
- input_paper = result['paper']
202
- related_papers = result['related_papers']
203
 
204
- # with st.empty():
205
 
206
- st.markdown(
207
- f'''Your input paper: \n\n<a href="{input_paper['url']}"><b>{input_paper['title']}</b></a> ({input_paper['year']})<hr />''',
208
- unsafe_allow_html=True)
209
 
210
- related_html = '<ul>'
211
 
212
- for i in range(len(related_papers['paper_id'])):
213
- related_html += f'''<li><a href="{related_papers['url_abs'][i]}">{related_papers['title'][i]}</a></li>'''
214
 
215
- related_html += '</ul>'
216
 
217
- st.markdown(f'''Related papers with similar {result['aspect']}: {related_html}''', unsafe_allow_html=True)
218
 
219
- except (TypeError, ValueError, KeyError) as e:
220
- st.error(f'**Error**: {e}')
221
 
222
  else:
223
  st.error('**Error**: No paper ID provided. Please provide a ArXiv ID or DOI.')
224
-
225
- # # Results
226
- # if 'results' in st.session_state and st.session_state.results:
227
- # first = True
228
- # for result in st.session_state.results[::-1]:
229
- # if not first:
230
- # st.markdown("---")
231
- # # st.markdown(f"ID:\n> {result['paperId']}")
232
- # # col_1, col_2, col_3 = st.columns([1,2,2])
233
- # # col_1.metric(label='', value=json.dumps(result))
234
- # # col_2.metric(label='Label', value=f"fooo")
235
- # # col_3.metric(label='Score', value=f"123")
236
- # input_paper = result['paper']
237
- # related_papers = result['related_papers']
238
- #
239
- # # with st.empty():
240
- #
241
- # st.markdown(f'''Your input paper: \n\n<a href="{input_paper['url']}"><b>{input_paper['title']}</b></a> ({input_paper['year']})<hr />''', unsafe_allow_html=True)
242
- #
243
- # related_html = '<ul>'
244
- #
245
- # for i in range(len(related_papers['paper_id'])):
246
- # related_html += f'''<li><a href="{related_papers['url_abs'][i]}">{related_papers['title'][i]}</a></li>'''
247
- #
248
- # related_html += '</ul>'
249
- #
250
- # st.markdown(f'''Related papers with similar {result['aspect']}: {related_html}''', unsafe_allow_html=True)
251
- #
252
- # # st.markdown(f'''Related papers: {related_html}''', unsafe_allow_html=True)
253
- #
254
- # first = False
 
69
  if isinstance(dataset, DatasetDict):
70
  dataset = dataset['train']
71
 
72
+ # load existing FAISS index for each aspect
73
  for a in aspects:
74
  dataset.load_faiss_index(f'{a}_embeddings', f'{a}_embeddings.faiss')
75
 
 
 
 
 
 
76
  return dataset
77
 
78
 
 
94
 
95
 
96
  def find_related_papers(paper_id, user_aspect):
97
+ with st.spinner('Searching for related papers...'):
 
 
98
 
99
+ paper = get_paper(paper_id)
 
100
 
101
+ if paper is None or 'title' not in paper or paper['title'] is None or 'abstract' not in paper or paper['abstract'] is None:
102
+ raise ValueError(f'Could not retrieve title and abstract for input paper: {paper_id}')
103
 
104
+ title_abs = paper['title'] + ': ' + paper['abstract']
 
105
 
106
+ # preprocess the input
107
+ inputs = tokenizer(title_abs, padding=True, truncation=True, return_tensors="pt", max_length=512)
108
 
109
+ # inference
110
+ outputs = aspect_to_model[user_aspect](**inputs)
 
 
111
 
112
+ # logger.info(f'attention_mask: {inputs["attention_mask"].shape}')
113
+ #
114
+ # logger.info(f'Outputs: {outputs["last_hidden_state"]}')
115
+ # logger.info(f'Outputs: {outputs["last_hidden_state"].shape}')
116
 
117
+ # Mean pool the token-level embeddings to get sentence-level embeddings
118
+ embeddings = torch.sum(
119
+ outputs["last_hidden_state"] * inputs['attention_mask'].unsqueeze(-1), dim=1
120
+ ) / torch.clamp(torch.sum(inputs['attention_mask'], dim=1, keepdims=True), min=1e-9)
121
 
122
+ result = dict(
123
+ paper=paper,
124
+ aspect=user_aspect,
125
+ )
126
 
127
+ result.update(dict(
128
+ #embeddings=embeddings.tolist(),
129
+ ))
130
 
131
+ # Retrieval
132
+ prompt = embeddings.detach().numpy()[0]
133
+ scores, retrieved_examples = dataset.get_nearest_examples(f'{user_aspect}_embeddings', prompt, k=10)
134
 
135
+ result.update(dict(
136
+ related_papers=retrieved_examples,
137
+ ))
138
 
139
  return result
140
 
141
 
 
 
 
 
142
  # Page
143
  st.title('Aspect-based Paper Similarity')
144
  st.markdown("""This demo showcases [Specialized Document Embeddings for Aspect-based Research Paper Similarity](#TODO).""")
145
 
146
  # Introduction
147
+ st.markdown(f"""The model was trained using a triplet loss on machine learning papers from the [paperswithcode.com](https://paperswithcode.com/) corpus with the objective of pulling embeddings of papers with the same task, method, or datasetclose together.
148
+ For a more comprehensive overview of the model check out the [model card on πŸ€— Model Hub]({model_hub_url}) or read [our paper](#TODO).""")
149
  st.markdown("""Enter a ArXiv ID or a DOI of a paper for that you want find similar papers.
150
 
151
  Try it yourself! πŸ‘‡""",
 
159
  placeholder='Any DOI, ACL, or ArXiv ID'
160
  )
161
 
162
+ example_labels = {
163
+ "arXiv:1902.06818": "Data augmentation for low resource sentiment analysis using generative adversarial networks",
164
+ "arXiv:2202.06671": "Neighborhood Contrastive Learning for Scientific Document Representations with Citation Embeddings",
165
+ "ACL:N19-1423": "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding",
166
+ "10.18653/v1/S16-1001": "SemEval-2016 Task 4: Sentiment Analysis in Twitter",
167
+ "10.1145/3065386": "ImageNet classification with deep convolutional neural networks",
168
+ }
169
+
170
  example = st.selectbox(
171
+ label='Or select an example:',
172
+ options=list(example_labels.keys()),
173
+ format_func=lambda option_key: f'{example_labels[option_key]} ({option_key})',
 
 
174
  )
175
 
 
 
 
 
 
176
  user_aspect = st.radio(
177
  label="In what aspect are you interested?",
178
  options=aspects
 
184
  # Listener
185
  if submitted:
186
  if paper_id or example:
187
+ try:
188
+ result = find_related_papers(paper_id if paper_id else example, user_aspect)
 
189
 
190
+ input_paper = result['paper']
191
+ related_papers = result['related_papers']
192
 
193
+ # with st.empty():
194
 
195
+ st.markdown(
196
+ f'''Your input paper: \n\n<a href="{input_paper['url']}"><b>{input_paper['title']}</b></a> ({input_paper['year']})<hr />''',
197
+ unsafe_allow_html=True)
198
 
199
+ related_html = '<ul>'
200
 
201
+ for i in range(len(related_papers['paper_id'])):
202
+ related_html += f'''<li><a href="{related_papers['url_abs'][i]}">{related_papers['title'][i]}</a></li>'''
203
 
204
+ related_html += '</ul>'
205
 
206
+ st.markdown(f'''Related papers with similar {result['aspect']}: {related_html}''', unsafe_allow_html=True)
207
 
208
+ except (TypeError, ValueError, KeyError) as e:
209
+ st.error(f'**Error**: {e}')
210
 
211
  else:
212
  st.error('**Error**: No paper ID provided. Please provide a ArXiv ID or DOI.')