jskim committed on
Commit 6004e76
1 Parent(s): 77829a1

Adding a knob to control the number of highlights. Replacing the main model with specter2, and using specter2 for the sentence-level highlights as well.

Files changed (4)
  1. app.py +103 -38
  2. details.html +2 -2
  3. input_format.py +1 -16
  4. score.py +52 -40
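
As context for the model swap described in the commit message, here is a minimal sketch of reusing one SPECTER2-style encoder for both the document-level and the sentence-level scores, mirroring the `sent_model = doc_model` aliasing in `app.py`. The `allenai/specter2_base` checkpoint name and the mean-pooling step are illustrative assumptions, not necessarily the app's exact setup.

```python
# Sketch: one SPECTER2-style encoder reused for document- and sentence-level scoring.
# Checkpoint name and pooling are illustrative assumptions, not the app's exact code.
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
model = AutoModel.from_pretrained('allenai/specter2_base')
model.eval()

def embed(texts):
    # Tokenize a batch of texts and mean-pool the last hidden states.
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        out = model(**batch).last_hidden_state          # (B, T, H)
    mask = batch['attention_mask'].unsqueeze(-1)        # (B, T, 1)
    return (out * mask).sum(1) / mask.sum(1)            # (B, H)

# The same embed() can score whole abstracts (affinity) or individual sentences.
submission = "We propose a method for matching reviewers to papers."
candidate = "This paper studies automatic reviewer assignment."
sim = torch.nn.functional.cosine_similarity(embed([submission]), embed([candidate]))
print(float(sim))  # document-level affinity score
```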
app.py CHANGED
@@ -24,15 +24,14 @@ doc_model.to(device)
 sent_model = doc_model # have the same model for document and sentence level
 
 # OR specify different model for sentence level
-# sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
-# sent_model.to(device)
+#sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
+#sent_model.to(device)
 
 def get_similar_paper(
     title_input,
     abstract_text_input,
     author_id_input,
     results={}, # this state variable will be updated and returned
-    #progress=gr.Progress()
 ):
     progress = gr.Progress()
     num_papers_show = 10 # number of top papers to show from the reviewer
@@ -82,7 +81,7 @@ def get_similar_paper(
     print('obtaining highlights..')
     start = time.time()
     input_sentences = sent_tokenize(abstract_text_input)
-    num_sents = len(input_sentences)
+    num_input_sents = len(input_sentences)
 
     for aa, (tt, ab, ds, url) in enumerate(zip(titles, abstracts, doc_scores, paper_urls)):
         # Compute sent-level and phrase-level affinity scores for each papers
@@ -91,21 +90,26 @@
             tokenizer,
             abstract_text_input,
             ab,
-            K=2 # top two sentences from the candidate
+            K=None, # top two sentences from the candidate
+            top_pair_num=3, # top five sentence pairs to show upfront
         )
+        num_cand_sents = sent_ids.shape[1]
 
         # get scores for each word in the format for Gradio Interpretation component
         word_scores = dict()
-        for i in range(num_sents):
-            ww, ss = remove_spaces(info['all_words'], info[i]['scores'])
-            word_scores[str(i)] = {
-                "original": ab,
-                "interpretation": list(zip(ww, ss))
-            }
+        for i in range(num_input_sents):
+            word_scores[str(i)] = dict()
+            for j in range(1, num_cand_sents+1):
+                ww, ss = remove_spaces(info['all_words'], info[i][j]['scores'])
+                word_scores[str(i)][str(j)] = {
+                    "original": ab,
+                    "interpretation": list(zip(ww, ss))
+                }
 
         results[display_title[aa]] = {
             'title': tt,
             'abstract': ab,
+            'num_cand_sents': num_cand_sents,
             'doc_score': '%0.3f'%ds,
             'source_sentences': input_sentences,
             'highlight': word_scores,
@@ -117,6 +121,9 @@
     highlight_time = end - start
     print('done in [%0.2f] seconds'%(highlight_time))
 
+    # debugging only
+    pickle.dump(results, open('info.pkl', 'wb'))
+
     ## Set up output elements
 
     # first the list of top papers, sentences to select from, paper_title, affinity
@@ -180,13 +187,13 @@
     assert(len(out) == (top_num_info_show * 5 + 2) * top_papers_show + 5)
 
     out += [gr.update(value="""
-<h3>Top three relevant papers by the reviewer <a href="%s" target="_blank">%s</a></h3>
-
-For each paper, two sentence pairs (one from the submission, one from the paper) with the highest relevance scores are shown.
-
-**<span style="color:black;background-color:#65B5E3;">Blue highlights</span>**: phrases that appear in both sentences.
-"""%(author_id_input, results['name']),
-visible=True)] # result 1 description
+<h3>Top three relevant papers by the reviewer <a href="%s" target="_blank">%s</a></h3>
+
+For each paper, two sentence pairs (one from the submission, one from the paper) with the highest relevance scores are shown.
+
+**<span style="color:black;background-color:#65B5E3;">Blue highlights</span>**: phrases that appear in both sentences.
+"""%(author_id_input, results['name']),
+visible=True)] # result 1 description
 
     out += [gr.update(visible=True), gr.update(visible=True)] # demarcation line between results
 
@@ -195,11 +202,14 @@
 
     # result 2 description
     desc = """
-##### Click a paper by %s on the left (sorted by affinity scores), and a sentence from the submission on the right, to see which parts the paper are relevant.
+##### Click a paper by %s on the left (sorted by affinity scores), and a sentence from the submission on the right, to see which parts of the paper are relevant.
     """%results['name']
     out += [gr.update(value=desc)]
 
-    # add the search results to pass on to the Gradio State varaible
+    # slider to control the number of highlights
+    out += [gr.update(value=1, maximum=len(sent_tokenize(abstracts[0])))]
+
+    # finally add the search results to pass on to the Gradio State varaible
     out += [results]
 
     return tuple(out)
@@ -213,6 +223,7 @@ def show_more(info):
         gr.update(visible=True), # title row
         gr.update(visible=True), # affinity row
         gr.update(visible=True), # highlight legend
+        gr.update(visible=True), # highlight slider
         gr.update(visible=True), # highlight abstract
     )
 
@@ -226,33 +237,59 @@ def update_name(author_id_input):
 
     return gr.update(value=name)
 
-def change_sentence(selected_papers_radio, source_sent_choice, info={}):
+def change_sentence(
+    selected_papers_radio,
+    source_sent_choice,
+    highlight_slider,
+    info={}
+):
     # change the output highlight based on the sentence selected from the submission
     if len(info.keys()) != 0: # if the info is not empty
         source_sents = info[selected_papers_radio]['source_sentences']
         highlights = info[selected_papers_radio]['highlight']
-        for i, s in enumerate(source_sents):
-            if source_sent_choice == s:
-                return highlights[str(i)]
+        idx = source_sents.index(source_sent_choice)
+        return highlights[str(idx)][str(highlight_slider)]
     else:
         return
 
-def change_paper(selected_papers_radio, source_sent_choice, info={}):
+def change_paper(
+    selected_papers_radio,
+    source_sent_choice,
+    highlight_slider,
+    info={}
+):
     if len(info.keys()) != 0: # if the info is not empty
         source_sents = info[selected_papers_radio]['source_sentences']
         title = info[selected_papers_radio]['title']
+        num_sents = info[selected_papers_radio]['num_cand_sents']
        abstract = info[selected_papers_radio]['abstract']
         aff_score = info[selected_papers_radio]['doc_score']
         highlights = info[selected_papers_radio]['highlight']
         url = info[selected_papers_radio]['url']
         title_out = """<a href="%s" target="_blank"><h5>%s</h5></a>"""%(url, title)
         aff_score_out = '##### Affinity Score: %s'%aff_score
-        for i, s in enumerate(source_sents):
-            if source_sent_choice == s:
-                return title_out, abstract, aff_score_out, highlights[str(i)]
+        idx = source_sents.index(source_sent_choice)
+        if highlight_slider <= num_sents:
+            return title_out, abstract, aff_score_out, highlights[str(idx)][str(highlight_slider)], gr.update(value=highlight_slider, maximum=num_sents)
+        else: # if the slider is set to more than the current number of sentences, show the max number of highlights
+            return title_out, abstract, aff_score_out, highlights[str(idx)][str(num_sents)], gr.update(value=num_sents, maximum=num_sents)
     else:
         return
 
+def change_num_highlight(
+    selected_papers_radio,
+    source_sent_choice,
+    highlight_slider,
+    info={}
+):
+    if len(info.keys()) != 0: # if the info is not empty
+        source_sents = info[selected_papers_radio]['source_sentences']
+        highlights = info[selected_papers_radio]['highlight']
+        idx = source_sents.index(source_sent_choice)
+        return highlights[str(idx)][str(highlight_slider)]
+    else:
+        return
+
 with gr.Blocks(css='style.css') as demo:
     info = gr.State({}) # cached search results as a State variable shared throughout
 
@@ -281,6 +318,7 @@ R2P2 provides more information about each reviewer. It searches for the **most r
     gr.HTML(more_details_instruction)
     gr.Markdown("""---""")
 
+
     ### INPUT
     with gr.Row() as input_row:
         with gr.Column(scale=3):
@@ -311,7 +349,6 @@ R2P2 provides more information about each reviewer. It searches for the **most r
 
     with gr.Row():
         search_status = gr.Textbox(label='Search Status', interactive=False, visible=False)
-
 
 
     ### OVERVIEW
@@ -416,13 +453,11 @@ R2P2 provides more information about each reviewer. It searches for the **most r
 
     # Highlight description
     hl_desc = """
-**<span style="color:black;background-color:#DB7262;">Red</span>**: sentences simiar to the selected sentence from submission. Darker = more similar.
-
-**<span style="color:black;background-color:#65B5E3;">Blue</span>**: phrases that appear in both sentences.
-
----
-"""
-    # TODO allow users to change the number of highlights to show?
+<font size="2">**<span style="color:black;background-color:#DB7262;">Red</span>**: sentences simiar to the selected sentence from submission. Darker = more similar.</font>
+
+<font size="2">**<span style="color:black;background-color:#65B5E3;">Blue</span>**: phrases that appear in both sentences.</font>
+"""
+    #---"""
     # show multiple papers in radio check box to select from
     paper_abstract = gr.Textbox(label='Abstract', interactive=False, visible=False)
     with gr.Row():
@@ -442,17 +477,29 @@ R2P2 provides more information about each reviewer. It searches for the **most r
         with gr.Column(scale=3):
             # selected paper and highlight
             with gr.Row():
+                # slider for highlight amount
+                highlight_slider = gr.Slider(
+                    label='Number of Highlighted Sentences',
+                    minimum=1,
+                    maximum=15,
+                    step=1,
+                    value=2,
+                    visible=False
+                )
+            with gr.Row():
+                # highlight legend
                 highlight_legend = gr.Markdown(value=hl_desc, visible=False)
             with gr.Row(visible=False) as title_row:
+                # selected paper title
                 paper_title = gr.Markdown(value='')
             with gr.Row(visible=False) as aff_row:
+                # selected paper's affinity score
                 affinity = gr.Markdown(value='')
             with gr.Row(visible=False) as hl_row:
                 # highlighted text from paper
                 highlight = gr.components.Interpretation(paper_abstract)
 
 
-
     ### EVENT LISTENERS
 
     compute_btn.click(
@@ -517,6 +564,7 @@ R2P2 provides more information about each reviewer. It searches for the **most r
             demarc2,
             search_status,
             result2_desc,
+            highlight_slider,
            info,
         ],
         show_progress=True,
@@ -534,7 +582,8 @@ R2P2 provides more information about each reviewer. It searches for the **most r
             title_row,
             aff_row,
             highlight_legend,
-            hl_row
+            highlight_slider,
+            hl_row,
         ]
     )
 
@@ -544,6 +593,7 @@ R2P2 provides more information about each reviewer. It searches for the **most r
         inputs=[
             selected_papers_radio,
             source_sentences,
+            highlight_slider,
            info
         ],
         outputs=highlight
@@ -555,12 +605,27 @@ R2P2 provides more information about each reviewer. It searches for the **most r
         inputs=[
             selected_papers_radio,
             source_sentences,
+            highlight_slider,
            info,
         ],
         outputs= [
             paper_title,
             paper_abstract,
             affinity,
+            highlight,
+            highlight_slider
+        ]
+    )
+
+    highlight_slider.change(
+        fn=change_num_highlight,
+        inputs=[
+            selected_papers_radio,
+            source_sentences,
+            highlight_slider,
+            info
+        ],
+        outputs=[
             highlight
         ]
     )
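
The new `highlight_slider` follows the usual Gradio pattern: a `Slider` whose `.change` event re-renders the highlight output for the current selection. A stripped-down sketch of that wiring, with a toy `render` function standing in for the app's precomputed `highlights[str(sent_idx)][str(k)]` lookup (component names below are illustrative, not the app's actual code):

```python
# Minimal sketch of the slider-driven re-rendering pattern used in app.py.
# Component names and the render() logic are illustrative only.
import gradio as gr

SENTENCES = ["First candidate sentence.", "Second one.", "Third one.", "Fourth one."]

def render(k):
    # Show the top-k "highlighted" sentences; app.py instead looks up
    # precomputed word scores keyed by the sentence index and k.
    return " ".join(SENTENCES[:int(k)])

with gr.Blocks() as demo:
    slider = gr.Slider(minimum=1, maximum=len(SENTENCES), step=1, value=2,
                       label='Number of Highlighted Sentences')
    preview = gr.Textbox(label='Highlighted abstract (toy)')
    slider.change(fn=render, inputs=slider, outputs=preview)

if __name__ == '__main__':
    demo.launch()
```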
details.html CHANGED
@@ -9,8 +9,8 @@ The tool is developed by <a href="https://wnstlr.github.io", target="_blank">Joo
 <h1>What Happens Behind the Scenes</h1>
 <ul>
 <li> The tool retrieves the reviewer's previous publications using <a href="https://www.semanticscholar.org/product/api", target="_blank">Semantic Scholar API</a>.</li>
-<li> The tool computes the affinity score between the submission abstract and each paper's abstract, using text representations from a <a href="https://github.com/allenai/specter/tree/master/specter", target="_blank">language model fine-tuned on academic papers</a>.</li>
-<li> The tool computes pairwise sentence relevance scores between the submission abstract and the reviewer paper's abstract, using text representations from a <a href="https://huggingface.co/sentence-transformers/gtr-t5-base", target="_blank">sentence-level langauge model</a>.</li>
+<li> The tool computes the affinity score between the submission abstract and each paper's abstract, using text representations from a <a href="https://huggingface.co/allenai/specter2", target="_blank">language model fine-tuned on academic papers</a>.</li>
+<li> The tool then computes pairwise sentence relevance scores between the submission abstract and the reviewer paper's abstract, using text representations from <a href="https://huggingface.co/allenai/specter2", target="_blank">the same model</a>.</li>
 <li> The tool highlights overlapping words (nouns) between setence pairs using <a href="https://www.nltk.org/book/ch05.html", target="_blank">POS tagging</a>.</li>
 </ul>
 
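The last bullet above (overlapping nouns between sentence pairs) corresponds to the `get_match_phrase` step in `score.py`. Below is a rough sketch of that idea with NLTK POS tagging; the `overlapping_nouns` function is hypothetical and only illustrates the approach, not the app's exact logic.

```python
# Rough sketch of the noun-overlap highlighting idea using NLTK POS tags.
# overlapping_nouns() is a hypothetical helper mirroring what score.py's
# get_match_phrase appears to do; details may differ.
import nltk
from nltk import word_tokenize, pos_tag

nltk.download('punkt', quiet=True)
nltk.download('averaged_perceptron_tagger', quiet=True)

def overlapping_nouns(query_sent, cand_sent):
    # Collect nouns (tags starting with NN) from the query sentence...
    query_nouns = {w.lower() for w, t in pos_tag(word_tokenize(query_sent)) if t.startswith('NN')}
    # ...and return candidate words that are nouns appearing in the query.
    return [w for w, t in pos_tag(word_tokenize(cand_sent))
            if t.startswith('NN') and w.lower() in query_nouns]

print(overlapping_nouns(
    "We match reviewers to submissions using paper abstracts.",
    "Abstracts of previous papers are used to rank reviewers."
))  # nouns shared with the query, e.g. ['Abstracts', 'reviewers']
```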
input_format.py CHANGED
@@ -81,19 +81,4 @@ def get_text_from_author_id(author_id, max_count=150):
     papers = data['papers'][:max_count]
     name = data['name']
 
-    return name, papers
-
-## TODO Preprocess Extracted Texts from PDFs
-# Get a portion of the text for actual task
-
-def get_title(text):
-    pass
-
-def get_abstract(text):
-    pass
-
-def get_introduction(text):
-    pass
-
-def get_conclusion(text):
-    pass
+    return name, papers
score.py CHANGED
@@ -40,12 +40,9 @@ def get_top_k(score_mat, K=3):
     """
     Pick top K sentences to show
     """
-    idx = torch.argsort(-score_mat)
-    picked_sent = idx[:,:K]
-    picked_scores = torch.vstack(
-        [score_mat[i,picked_sent[i]] for i in range(picked_sent.shape[0])]
-    )
-
+    picked_scores, picked_sent = torch.sort(-score_mat, axis=1)
+    picked_sent = picked_sent[:,:K]
+    picked_scores = -picked_scores[:,:K]
     return picked_sent, picked_scores
 
 def get_words(sent):
@@ -57,7 +54,6 @@ def get_words(sent):
     sent_start_id = [] # keep track of the word index where the new sentence starts
     counter = 0
     for x in sent:
-        #w = x.split()
         w = word_tokenize(x)
         nw = len(w)
         counter += nw
@@ -180,11 +176,21 @@ def remove_spaces(words, attrs):
     assert(len(word_out) == len(attr_out))
     return word_out, attr_out
 
+def scale_scores(arr, vmin=0.1, vmax=1):
+    # rescale positive and negative attributions to be between vmin and vmax.
+    # while keeping 0 at 0.
+    pos_max, pos_min = np.max(arr[arr > 0]), np.min(arr[arr > 0])
+    out = (arr - pos_min) / (pos_max - pos_min) * (vmax - vmin) + vmin
+    idx = np.where(arr == 0.0)[0]
+    out[idx] = 0.0
+    return out
+
 def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
     """
     Mark the words that are highlighted, both by in terms of sentence and phrase
     """
     num_query_sent = sent_ids.shape[0]
+    num_cand_sent = sent_ids.shape[1]
     num_words = len(all_words)
 
     output = dict()
@@ -193,53 +199,59 @@ def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scor
 
     # for each query sentence, mark the highlight information
     for i in range(num_query_sent):
-        query_words = word_tokenize(query_sents[i])
-        is_selected_sent = np.zeros(num_words)
-        is_selected_phrase = np.zeros(num_words)
-        word_scores = np.zeros(num_words)
-
-        # for each selected sentences from the candidate, compile information
-        for sid, sscore in zip(sent_ids[i], sent_scores[i]):
-            #print(len(sent_start_id), sid, sid+1)
-            if sid+1 < len(sent_start_id):
-                sent_range = (sent_start_id[sid], sent_start_id[sid+1])
-                is_selected_sent[sent_range[0]:sent_range[1]] = 1
-                word_scores[sent_range[0]:sent_range[1]] = sscore
-                _, is_selected_phrase[sent_range[0]:sent_range[1]] = \
-                    get_match_phrase(query_words, all_words[sent_range[0]:sent_range[1]])
-            else:
-                is_selected_sent[sent_start_id[sid]:] = 1
-                word_scores[sent_start_id[sid]:] = sscore
-                _, is_selected_phrase[sent_start_id[sid]:] = \
-                    get_match_phrase(query_words, all_words[sent_start_id[sid]:])
-
-        # update selected phrase scores (-1 meaning a different color in gradio)
-        word_scores[is_selected_sent+is_selected_phrase==2] = -0.5
-
-        output[i] = {
-            'is_selected_sent': is_selected_sent,
-            'is_selected_phrase': is_selected_phrase,
-            'scores': word_scores
-        }
+        output[i] = dict()
+        for j in range(1, num_cand_sent+1): # for each number of selected sentences from candidate
+            query_words = word_tokenize(query_sents[i])
+            is_selected_sent = np.zeros(num_words)
+            is_selected_phrase = np.zeros(num_words)
+            word_scores = np.zeros(num_words)
+
+            # for each selected sentences from the candidate, compile information
+            for sid, sscore in zip(sent_ids[i][:j], sent_scores[i][:j]):
+                #print(len(sent_start_id), sid, sid+1)
+                if sid+1 < len(sent_start_id):
+                    sent_range = (sent_start_id[sid], sent_start_id[sid+1])
+                    is_selected_sent[sent_range[0]:sent_range[1]] = 1
+                    word_scores[sent_range[0]:sent_range[1]] = sscore
+                    _, is_selected_phrase[sent_range[0]:sent_range[1]] = \
+                        get_match_phrase(query_words, all_words[sent_range[0]:sent_range[1]])
+                else:
+                    is_selected_sent[sent_start_id[sid]:] = 1
+                    word_scores[sent_start_id[sid]:] = sscore
+                    _, is_selected_phrase[sent_start_id[sid]:] = \
+                        get_match_phrase(query_words, all_words[sent_start_id[sid]:])
+
+            # scale the word_scores: maximum value gets the darkest, minimum value gets the lightest color
+            if j > 1:
+                word_scores = scale_scores(word_scores)
+
+            # update selected phrase scores (-1 meaning a different color in gradio)
+            word_scores[is_selected_sent+is_selected_phrase==2] = -0.5
+
+            output[i][j] = {
+                'is_selected_sent': is_selected_sent,
+                'is_selected_phrase': is_selected_phrase,
+                'scores': word_scores
+            }
 
     return output
 
-def get_highlight_info(model, tokenizer, text1, text2, K=None):
+def get_highlight_info(model, tokenizer, text1, text2, K=None, top_pair_num=5):
     """
     Get highlight information from two texts
     """
     sent1 = sent_tokenize(text1) # query
     sent2 = sent_tokenize(text2) # candidate
-    if K is None: # if K is not set, select based on the length of the candidate
-        K = int(len(sent2) / 3)
     score_mat = compute_sentencewise_scores(model, sent1, sent2, tokenizer=tokenizer)
-
+
+    if K is None: # if K is not set, get all information
+        K = score_mat.shape[1]
+
     sent_ids, sent_scores = get_top_k(score_mat, K=K)
     words2, all_words2, sent_start_id2 = get_words(sent2)
     info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
 
-    # get top sentence pairs from the query and candidate (score, index_pair)
-    top_pair_num = 5
+    # get top sentence pairs from the query and candidate (score, index_pair) to show upfront
     top_pairs = []
     ii = np.unravel_index(np.argsort(np.array(sent_scores).ravel())[-top_pair_num:], sent_scores.shape)
     for i, j in zip(ii[0][::-1], ii[1][::-1]):
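
To see what the reworked `get_top_k` and the new `scale_scores` do, here is a small self-contained check. Both functions are lightly adapted from the diff above (using `dim=1` rather than `axis=1` in `torch.sort`), and the toy score matrix and word scores are made up for illustration.

```python
# Toy check of the top-k selection and score rescaling added in this commit.
# The inputs below are made up; the two functions mirror the diff above.
import numpy as np
import torch

def get_top_k(score_mat, K=3):
    # Sort each row in descending order (via negation) and keep the first K
    # columns, returning both sentence indices and their scores.
    picked_scores, picked_sent = torch.sort(-score_mat, dim=1)
    picked_sent = picked_sent[:, :K]
    picked_scores = -picked_scores[:, :K]
    return picked_sent, picked_scores

def scale_scores(arr, vmin=0.1, vmax=1):
    # Rescale positive attributions into [vmin, vmax] while keeping zeros at zero.
    pos_max, pos_min = np.max(arr[arr > 0]), np.min(arr[arr > 0])
    out = (arr - pos_min) / (pos_max - pos_min) * (vmax - vmin) + vmin
    out[np.where(arr == 0.0)[0]] = 0.0
    return out

score_mat = torch.tensor([[0.2, 0.9, 0.5],
                          [0.7, 0.1, 0.4]])
ids, scores = get_top_k(score_mat, K=2)
print(ids)     # tensor([[1, 2], [0, 2]])
print(scores)  # tensor([[0.9000, 0.5000], [0.7000, 0.4000]])

word_scores = np.array([0.0, 0.5, 0.9, 0.0, 0.2])
print(scale_scores(word_scores))  # zeros stay 0; 0.2 -> 0.1, 0.9 -> 1.0
```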