Spaces:
Runtime error
Runtime error
adding knob to control number of highlights. replacing the main model with specter2. using specter2 for sentence-level highlight as well.
Browse files- app.py +103 -38
- details.html +2 -2
- input_format.py +1 -16
- score.py +52 -40
app.py
CHANGED
@@ -24,15 +24,14 @@ doc_model.to(device)
|
|
24 |
sent_model = doc_model # have the same model for document and sentence level
|
25 |
|
26 |
# OR specify different model for sentence level
|
27 |
-
#
|
28 |
-
#
|
29 |
|
30 |
def get_similar_paper(
|
31 |
title_input,
|
32 |
abstract_text_input,
|
33 |
author_id_input,
|
34 |
results={}, # this state variable will be updated and returned
|
35 |
-
#progress=gr.Progress()
|
36 |
):
|
37 |
progress = gr.Progress()
|
38 |
num_papers_show = 10 # number of top papers to show from the reviewer
|
@@ -82,7 +81,7 @@ def get_similar_paper(
|
|
82 |
print('obtaining highlights..')
|
83 |
start = time.time()
|
84 |
input_sentences = sent_tokenize(abstract_text_input)
|
85 |
-
|
86 |
|
87 |
for aa, (tt, ab, ds, url) in enumerate(zip(titles, abstracts, doc_scores, paper_urls)):
|
88 |
# Compute sent-level and phrase-level affinity scores for each papers
|
@@ -91,21 +90,26 @@ def get_similar_paper(
|
|
91 |
tokenizer,
|
92 |
abstract_text_input,
|
93 |
ab,
|
94 |
-
K=
|
|
|
95 |
)
|
|
|
96 |
|
97 |
# get scores for each word in the format for Gradio Interpretation component
|
98 |
word_scores = dict()
|
99 |
-
for i in range(
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
|
|
105 |
|
106 |
results[display_title[aa]] = {
|
107 |
'title': tt,
|
108 |
'abstract': ab,
|
|
|
109 |
'doc_score': '%0.3f'%ds,
|
110 |
'source_sentences': input_sentences,
|
111 |
'highlight': word_scores,
|
@@ -117,6 +121,9 @@ def get_similar_paper(
|
|
117 |
highlight_time = end - start
|
118 |
print('done in [%0.2f] seconds'%(highlight_time))
|
119 |
|
|
|
|
|
|
|
120 |
## Set up output elements
|
121 |
|
122 |
# first the list of top papers, sentences to select from, paper_title, affinity
|
@@ -180,13 +187,13 @@ def get_similar_paper(
|
|
180 |
assert(len(out) == (top_num_info_show * 5 + 2) * top_papers_show + 5)
|
181 |
|
182 |
out += [gr.update(value="""
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
|
191 |
out += [gr.update(visible=True), gr.update(visible=True)] # demarcation line between results
|
192 |
|
@@ -195,11 +202,14 @@ def get_similar_paper(
|
|
195 |
|
196 |
# result 2 description
|
197 |
desc = """
|
198 |
-
##### Click a paper by %s on the left (sorted by affinity scores), and a sentence from the submission on the right, to see which parts the paper are relevant.
|
199 |
"""%results['name']
|
200 |
out += [gr.update(value=desc)]
|
201 |
|
202 |
-
#
|
|
|
|
|
|
|
203 |
out += [results]
|
204 |
|
205 |
return tuple(out)
|
@@ -213,6 +223,7 @@ def show_more(info):
|
|
213 |
gr.update(visible=True), # title row
|
214 |
gr.update(visible=True), # affinity row
|
215 |
gr.update(visible=True), # highlight legend
|
|
|
216 |
gr.update(visible=True), # highlight abstract
|
217 |
)
|
218 |
|
@@ -226,33 +237,59 @@ def update_name(author_id_input):
|
|
226 |
|
227 |
return gr.update(value=name)
|
228 |
|
229 |
-
def change_sentence(
|
|
|
|
|
|
|
|
|
|
|
230 |
# change the output highlight based on the sentence selected from the submission
|
231 |
if len(info.keys()) != 0: # if the info is not empty
|
232 |
source_sents = info[selected_papers_radio]['source_sentences']
|
233 |
highlights = info[selected_papers_radio]['highlight']
|
234 |
-
|
235 |
-
|
236 |
-
return highlights[str(i)]
|
237 |
else:
|
238 |
return
|
239 |
|
240 |
-
def change_paper(
|
|
|
|
|
|
|
|
|
|
|
241 |
if len(info.keys()) != 0: # if the info is not empty
|
242 |
source_sents = info[selected_papers_radio]['source_sentences']
|
243 |
title = info[selected_papers_radio]['title']
|
|
|
244 |
abstract = info[selected_papers_radio]['abstract']
|
245 |
aff_score = info[selected_papers_radio]['doc_score']
|
246 |
highlights = info[selected_papers_radio]['highlight']
|
247 |
url = info[selected_papers_radio]['url']
|
248 |
title_out = """<a href="%s" target="_blank"><h5>%s</h5></a>"""%(url, title)
|
249 |
aff_score_out = '##### Affinity Score: %s'%aff_score
|
250 |
-
|
251 |
-
|
252 |
-
|
|
|
|
|
253 |
else:
|
254 |
return
|
255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
with gr.Blocks(css='style.css') as demo:
|
257 |
info = gr.State({}) # cached search results as a State variable shared throughout
|
258 |
|
@@ -281,6 +318,7 @@ R2P2 provides more information about each reviewer. It searches for the **most r
|
|
281 |
gr.HTML(more_details_instruction)
|
282 |
gr.Markdown("""---""")
|
283 |
|
|
|
284 |
### INPUT
|
285 |
with gr.Row() as input_row:
|
286 |
with gr.Column(scale=3):
|
@@ -311,7 +349,6 @@ R2P2 provides more information about each reviewer. It searches for the **most r
|
|
311 |
|
312 |
with gr.Row():
|
313 |
search_status = gr.Textbox(label='Search Status', interactive=False, visible=False)
|
314 |
-
|
315 |
|
316 |
|
317 |
### OVERVIEW
|
@@ -416,13 +453,11 @@ R2P2 provides more information about each reviewer. It searches for the **most r
|
|
416 |
|
417 |
# Highlight description
|
418 |
hl_desc = """
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
"""
|
425 |
-
# TODO allow users to change the number of highlights to show?
|
426 |
# show multiple papers in radio check box to select from
|
427 |
paper_abstract = gr.Textbox(label='Abstract', interactive=False, visible=False)
|
428 |
with gr.Row():
|
@@ -442,17 +477,29 @@ R2P2 provides more information about each reviewer. It searches for the **most r
|
|
442 |
with gr.Column(scale=3):
|
443 |
# selected paper and highlight
|
444 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
highlight_legend = gr.Markdown(value=hl_desc, visible=False)
|
446 |
with gr.Row(visible=False) as title_row:
|
|
|
447 |
paper_title = gr.Markdown(value='')
|
448 |
with gr.Row(visible=False) as aff_row:
|
|
|
449 |
affinity = gr.Markdown(value='')
|
450 |
with gr.Row(visible=False) as hl_row:
|
451 |
# highlighted text from paper
|
452 |
highlight = gr.components.Interpretation(paper_abstract)
|
453 |
|
454 |
|
455 |
-
|
456 |
### EVENT LISTENERS
|
457 |
|
458 |
compute_btn.click(
|
@@ -517,6 +564,7 @@ R2P2 provides more information about each reviewer. It searches for the **most r
|
|
517 |
demarc2,
|
518 |
search_status,
|
519 |
result2_desc,
|
|
|
520 |
info,
|
521 |
],
|
522 |
show_progress=True,
|
@@ -534,7 +582,8 @@ R2P2 provides more information about each reviewer. It searches for the **most r
|
|
534 |
title_row,
|
535 |
aff_row,
|
536 |
highlight_legend,
|
537 |
-
|
|
|
538 |
]
|
539 |
)
|
540 |
|
@@ -544,6 +593,7 @@ R2P2 provides more information about each reviewer. It searches for the **most r
|
|
544 |
inputs=[
|
545 |
selected_papers_radio,
|
546 |
source_sentences,
|
|
|
547 |
info
|
548 |
],
|
549 |
outputs=highlight
|
@@ -555,12 +605,27 @@ R2P2 provides more information about each reviewer. It searches for the **most r
|
|
555 |
inputs=[
|
556 |
selected_papers_radio,
|
557 |
source_sentences,
|
|
|
558 |
info,
|
559 |
],
|
560 |
outputs= [
|
561 |
paper_title,
|
562 |
paper_abstract,
|
563 |
affinity,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
highlight
|
565 |
]
|
566 |
)
|
|
|
24 |
sent_model = doc_model # have the same model for document and sentence level
|
25 |
|
26 |
# OR specify different model for sentence level
|
27 |
+
#sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
|
28 |
+
#sent_model.to(device)
|
29 |
|
30 |
def get_similar_paper(
|
31 |
title_input,
|
32 |
abstract_text_input,
|
33 |
author_id_input,
|
34 |
results={}, # this state variable will be updated and returned
|
|
|
35 |
):
|
36 |
progress = gr.Progress()
|
37 |
num_papers_show = 10 # number of top papers to show from the reviewer
|
|
|
81 |
print('obtaining highlights..')
|
82 |
start = time.time()
|
83 |
input_sentences = sent_tokenize(abstract_text_input)
|
84 |
+
num_input_sents = len(input_sentences)
|
85 |
|
86 |
for aa, (tt, ab, ds, url) in enumerate(zip(titles, abstracts, doc_scores, paper_urls)):
|
87 |
# Compute sent-level and phrase-level affinity scores for each papers
|
|
|
90 |
tokenizer,
|
91 |
abstract_text_input,
|
92 |
ab,
|
93 |
+
K=None, # top two sentences from the candidate
|
94 |
+
top_pair_num=3, # top five sentence pairs to show upfront
|
95 |
)
|
96 |
+
num_cand_sents = sent_ids.shape[1]
|
97 |
|
98 |
# get scores for each word in the format for Gradio Interpretation component
|
99 |
word_scores = dict()
|
100 |
+
for i in range(num_input_sents):
|
101 |
+
word_scores[str(i)] = dict()
|
102 |
+
for j in range(1, num_cand_sents+1):
|
103 |
+
ww, ss = remove_spaces(info['all_words'], info[i][j]['scores'])
|
104 |
+
word_scores[str(i)][str(j)] = {
|
105 |
+
"original": ab,
|
106 |
+
"interpretation": list(zip(ww, ss))
|
107 |
+
}
|
108 |
|
109 |
results[display_title[aa]] = {
|
110 |
'title': tt,
|
111 |
'abstract': ab,
|
112 |
+
'num_cand_sents': num_cand_sents,
|
113 |
'doc_score': '%0.3f'%ds,
|
114 |
'source_sentences': input_sentences,
|
115 |
'highlight': word_scores,
|
|
|
121 |
highlight_time = end - start
|
122 |
print('done in [%0.2f] seconds'%(highlight_time))
|
123 |
|
124 |
+
# debugging only
|
125 |
+
pickle.dump(results, open('info.pkl', 'wb'))
|
126 |
+
|
127 |
## Set up output elements
|
128 |
|
129 |
# first the list of top papers, sentences to select from, paper_title, affinity
|
|
|
187 |
assert(len(out) == (top_num_info_show * 5 + 2) * top_papers_show + 5)
|
188 |
|
189 |
out += [gr.update(value="""
|
190 |
+
<h3>Top three relevant papers by the reviewer <a href="%s" target="_blank">%s</a></h3>
|
191 |
+
|
192 |
+
For each paper, two sentence pairs (one from the submission, one from the paper) with the highest relevance scores are shown.
|
193 |
+
|
194 |
+
**<span style="color:black;background-color:#65B5E3;">Blue highlights</span>**: phrases that appear in both sentences.
|
195 |
+
"""%(author_id_input, results['name']),
|
196 |
+
visible=True)] # result 1 description
|
197 |
|
198 |
out += [gr.update(visible=True), gr.update(visible=True)] # demarcation line between results
|
199 |
|
|
|
202 |
|
203 |
# result 2 description
|
204 |
desc = """
|
205 |
+
##### Click a paper by %s on the left (sorted by affinity scores), and a sentence from the submission on the right, to see which parts of the paper are relevant.
|
206 |
"""%results['name']
|
207 |
out += [gr.update(value=desc)]
|
208 |
|
209 |
+
# slider to control the number of highlights
|
210 |
+
out += [gr.update(value=1, maximum=len(sent_tokenize(abstracts[0])))]
|
211 |
+
|
212 |
+
# finally add the search results to pass on to the Gradio State varaible
|
213 |
out += [results]
|
214 |
|
215 |
return tuple(out)
|
|
|
223 |
gr.update(visible=True), # title row
|
224 |
gr.update(visible=True), # affinity row
|
225 |
gr.update(visible=True), # highlight legend
|
226 |
+
gr.update(visible=True), # highlight slider
|
227 |
gr.update(visible=True), # highlight abstract
|
228 |
)
|
229 |
|
|
|
237 |
|
238 |
return gr.update(value=name)
|
239 |
|
240 |
+
def change_sentence(
    selected_papers_radio,
    source_sent_choice,
    highlight_slider,
    info=None
):
    """Return the highlight entry for the newly selected submission sentence.

    Args:
        selected_papers_radio: key of the currently selected paper in ``info``.
        source_sent_choice: the submission sentence picked by the user; it is
            looked up in that paper's ``'source_sentences'`` list.
        highlight_slider: requested number of highlighted sentences.
        info: cached search results. Each paper entry is expected to hold
            ``'source_sentences'`` and ``'highlight'``, where ``'highlight'``
            is keyed by ``str(sentence_index)`` and then ``str(count)``.

    Returns:
        The stored highlight entry, or ``None`` when there are no cached
        results, no matching paper, or no matching sentence.
    """
    # Nothing cached yet, or no (valid) paper selected: nothing to display.
    if not info or selected_papers_radio not in info:
        return None

    paper = info[selected_papers_radio]
    source_sents = paper['source_sentences']
    if source_sent_choice not in source_sents:
        return None

    per_count = paper['highlight'][str(source_sents.index(source_sent_choice))]

    # Clamp the slider to the counts actually stored for this paper (the same
    # guard change_paper applies), and normalise values such as 2.0 to '2' so
    # the string key matches.
    count = max(1, min(int(highlight_slider), len(per_count)))
    return per_count.get(str(count))
|
254 |
|
255 |
+
def change_paper(
    selected_papers_radio,
    source_sent_choice,
    highlight_slider,
    info={}
):
    """Build the display values for a newly selected paper.

    Returns a 5-tuple: the linked title (HTML), the abstract, the affinity
    score line (Markdown), the highlight entry for the chosen sentence, and a
    ``gr.update`` that re-bounds the highlight slider to this paper. Returns
    ``None`` when ``info`` is empty.
    """
    if not info.keys():  # no cached results yet
        return

    paper = info[selected_papers_radio]
    source_sents = paper['source_sentences']
    title = paper['title']
    num_sents = paper['num_cand_sents']
    abstract = paper['abstract']
    aff_score = paper['doc_score']
    highlights = paper['highlight']
    url = paper['url']

    title_out = """<a href="%s" target="_blank"><h5>%s</h5></a>"""%(url, title)
    aff_score_out = '##### Affinity Score: %s'%aff_score
    idx = source_sents.index(source_sent_choice)

    # If the slider asks for more sentences than this paper has, fall back to
    # the paper's maximum and move the slider there as well.
    shown = highlight_slider if highlight_slider <= num_sents else num_sents
    return (
        title_out,
        abstract,
        aff_score_out,
        highlights[str(idx)][str(shown)],
        gr.update(value=shown, maximum=num_sents),
    )
|
278 |
|
279 |
+
def change_num_highlight(
    selected_papers_radio,
    source_sent_choice,
    highlight_slider,
    info=None
):
    """Return the highlight entry for a new number of highlighted sentences.

    Args:
        selected_papers_radio: key of the currently selected paper in ``info``.
        source_sent_choice: the selected submission sentence; it is looked up
            in that paper's ``'source_sentences'`` list.
        highlight_slider: requested number of highlighted sentences.
        info: cached search results. Each paper entry is expected to hold
            ``'source_sentences'`` and ``'highlight'``, where ``'highlight'``
            is keyed by ``str(sentence_index)`` and then ``str(count)``.

    Returns:
        The stored highlight entry, or ``None`` when there are no cached
        results, no matching paper, or no matching sentence.
    """
    # The slider can change before any paper or sentence has been chosen; in
    # that case there is nothing to highlight, so return None rather than raise.
    if not info or selected_papers_radio not in info:
        return None

    paper = info[selected_papers_radio]
    source_sents = paper['source_sentences']
    if source_sent_choice not in source_sents:
        return None

    per_count = paper['highlight'][str(source_sents.index(source_sent_choice))]

    # Clamp to the counts actually stored, and normalise values such as 3.0
    # to '3' so the string key matches.
    count = max(1, min(int(highlight_slider), len(per_count)))
    return per_count.get(str(count))
|
292 |
+
|
293 |
with gr.Blocks(css='style.css') as demo:
|
294 |
info = gr.State({}) # cached search results as a State variable shared throughout
|
295 |
|
|
|
318 |
gr.HTML(more_details_instruction)
|
319 |
gr.Markdown("""---""")
|
320 |
|
321 |
+
|
322 |
### INPUT
|
323 |
with gr.Row() as input_row:
|
324 |
with gr.Column(scale=3):
|
|
|
349 |
|
350 |
with gr.Row():
|
351 |
search_status = gr.Textbox(label='Search Status', interactive=False, visible=False)
|
|
|
352 |
|
353 |
|
354 |
### OVERVIEW
|
|
|
453 |
|
454 |
# Highlight description
|
455 |
hl_desc = """
|
456 |
+
<font size="2">**<span style="color:black;background-color:#DB7262;">Red</span>**: sentences simiar to the selected sentence from submission. Darker = more similar.</font>
|
457 |
+
|
458 |
+
<font size="2">**<span style="color:black;background-color:#65B5E3;">Blue</span>**: phrases that appear in both sentences.</font>
|
459 |
+
"""
|
460 |
+
#---"""
|
|
|
|
|
461 |
# show multiple papers in radio check box to select from
|
462 |
paper_abstract = gr.Textbox(label='Abstract', interactive=False, visible=False)
|
463 |
with gr.Row():
|
|
|
477 |
with gr.Column(scale=3):
|
478 |
# selected paper and highlight
|
479 |
with gr.Row():
|
480 |
+
# slider for highlight amount
|
481 |
+
highlight_slider = gr.Slider(
|
482 |
+
label='Number of Highlighted Sentences',
|
483 |
+
minimum=1,
|
484 |
+
maximum=15,
|
485 |
+
step=1,
|
486 |
+
value=2,
|
487 |
+
visible=False
|
488 |
+
)
|
489 |
+
with gr.Row():
|
490 |
+
# highlight legend
|
491 |
highlight_legend = gr.Markdown(value=hl_desc, visible=False)
|
492 |
with gr.Row(visible=False) as title_row:
|
493 |
+
# selected paper title
|
494 |
paper_title = gr.Markdown(value='')
|
495 |
with gr.Row(visible=False) as aff_row:
|
496 |
+
# selected paper's affinity score
|
497 |
affinity = gr.Markdown(value='')
|
498 |
with gr.Row(visible=False) as hl_row:
|
499 |
# highlighted text from paper
|
500 |
highlight = gr.components.Interpretation(paper_abstract)
|
501 |
|
502 |
|
|
|
503 |
### EVENT LISTENERS
|
504 |
|
505 |
compute_btn.click(
|
|
|
564 |
demarc2,
|
565 |
search_status,
|
566 |
result2_desc,
|
567 |
+
highlight_slider,
|
568 |
info,
|
569 |
],
|
570 |
show_progress=True,
|
|
|
582 |
title_row,
|
583 |
aff_row,
|
584 |
highlight_legend,
|
585 |
+
highlight_slider,
|
586 |
+
hl_row,
|
587 |
]
|
588 |
)
|
589 |
|
|
|
593 |
inputs=[
|
594 |
selected_papers_radio,
|
595 |
source_sentences,
|
596 |
+
highlight_slider,
|
597 |
info
|
598 |
],
|
599 |
outputs=highlight
|
|
|
605 |
inputs=[
|
606 |
selected_papers_radio,
|
607 |
source_sentences,
|
608 |
+
highlight_slider,
|
609 |
info,
|
610 |
],
|
611 |
outputs= [
|
612 |
paper_title,
|
613 |
paper_abstract,
|
614 |
affinity,
|
615 |
+
highlight,
|
616 |
+
highlight_slider
|
617 |
+
]
|
618 |
+
)
|
619 |
+
|
620 |
+
highlight_slider.change(
|
621 |
+
fn=change_num_highlight,
|
622 |
+
inputs=[
|
623 |
+
selected_papers_radio,
|
624 |
+
source_sentences,
|
625 |
+
highlight_slider,
|
626 |
+
info
|
627 |
+
],
|
628 |
+
outputs=[
|
629 |
highlight
|
630 |
]
|
631 |
)
|
details.html
CHANGED
@@ -9,8 +9,8 @@ The tool is developed by <a href="https://wnstlr.github.io", target="_blank">Joo
|
|
9 |
<h1>What Happens Behind the Scenes</h1>
|
10 |
<ul>
|
11 |
<li> The tool retrieves the reviewer's previous publications using <a href="https://www.semanticscholar.org/product/api" target="_blank">Semantic Scholar API</a>.</li>
|
12 |
-
<li> The tool computes the affinity score between the submission abstract and each paper's abstract, using text representations from a <a href="https://
|
13 |
-
<li> The tool computes pairwise sentence relevance scores between the submission abstract and the reviewer paper's abstract, using text representations from
|
14 |
<li> The tool highlights overlapping words (nouns) between sentence pairs using <a href="https://www.nltk.org/book/ch05.html" target="_blank">POS tagging</a>.</li>
|
15 |
</ul>
|
16 |
|
|
|
9 |
<h1>What Happens Behind the Scenes</h1>
|
10 |
<ul>
|
11 |
<li> The tool retrieves the reviewer's previous publications using <a href="https://www.semanticscholar.org/product/api" target="_blank">Semantic Scholar API</a>.</li>
|
12 |
+
<li> The tool computes the affinity score between the submission abstract and each paper's abstract, using text representations from a <a href="https://huggingface.co/allenai/specter2" target="_blank">language model fine-tuned on academic papers</a>.</li>
|
13 |
+
<li> The tool then computes pairwise sentence relevance scores between the submission abstract and the reviewer paper's abstract, using text representations from <a href="https://huggingface.co/allenai/specter2" target="_blank">the same model</a>.</li>
|
14 |
<li> The tool highlights overlapping words (nouns) between sentence pairs using <a href="https://www.nltk.org/book/ch05.html" target="_blank">POS tagging</a>.</li>
|
15 |
</ul>
|
16 |
|
input_format.py
CHANGED
@@ -81,19 +81,4 @@ def get_text_from_author_id(author_id, max_count=150):
|
|
81 |
papers = data['papers'][:max_count]
|
82 |
name = data['name']
|
83 |
|
84 |
-
return name, papers
|
85 |
-
|
86 |
-
## TODO Preprocess Extracted Texts from PDFs
|
87 |
-
# Get a portion of the text for actual task
|
88 |
-
|
89 |
-
def get_title(text):
|
90 |
-
pass
|
91 |
-
|
92 |
-
def get_abstract(text):
|
93 |
-
pass
|
94 |
-
|
95 |
-
def get_introduction(text):
|
96 |
-
pass
|
97 |
-
|
98 |
-
def get_conclusion(text):
|
99 |
-
pass
|
|
|
81 |
papers = data['papers'][:max_count]
|
82 |
name = data['name']
|
83 |
|
84 |
+
return name, papers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
score.py
CHANGED
@@ -40,12 +40,9 @@ def get_top_k(score_mat, K=3):
|
|
40 |
"""
|
41 |
Pick top K sentences to show
|
42 |
"""
|
43 |
-
|
44 |
-
picked_sent =
|
45 |
-
picked_scores =
|
46 |
-
[score_mat[i,picked_sent[i]] for i in range(picked_sent.shape[0])]
|
47 |
-
)
|
48 |
-
|
49 |
return picked_sent, picked_scores
|
50 |
|
51 |
def get_words(sent):
|
@@ -57,7 +54,6 @@ def get_words(sent):
|
|
57 |
sent_start_id = [] # keep track of the word index where the new sentence starts
|
58 |
counter = 0
|
59 |
for x in sent:
|
60 |
-
#w = x.split()
|
61 |
w = word_tokenize(x)
|
62 |
nw = len(w)
|
63 |
counter += nw
|
@@ -180,11 +176,21 @@ def remove_spaces(words, attrs):
|
|
180 |
assert(len(word_out) == len(attr_out))
|
181 |
return word_out, attr_out
|
182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
|
184 |
"""
|
185 |
Mark the words that are highlighted, both by in terms of sentence and phrase
|
186 |
"""
|
187 |
num_query_sent = sent_ids.shape[0]
|
|
|
188 |
num_words = len(all_words)
|
189 |
|
190 |
output = dict()
|
@@ -193,53 +199,59 @@ def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scor
|
|
193 |
|
194 |
# for each query sentence, mark the highlight information
|
195 |
for i in range(num_query_sent):
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
for sid, sscore in zip(sent_ids[i], sent_scores[i]):
|
203 |
-
#print(len(sent_start_id), sid, sid+1)
|
204 |
-
if sid+1 < len(sent_start_id):
|
205 |
-
sent_range = (sent_start_id[sid], sent_start_id[sid+1])
|
206 |
-
is_selected_sent[sent_range[0]:sent_range[1]] = 1
|
207 |
-
word_scores[sent_range[0]:sent_range[1]] = sscore
|
208 |
-
_, is_selected_phrase[sent_range[0]:sent_range[1]] = \
|
209 |
-
get_match_phrase(query_words, all_words[sent_range[0]:sent_range[1]])
|
210 |
-
else:
|
211 |
-
is_selected_sent[sent_start_id[sid]:] = 1
|
212 |
-
word_scores[sent_start_id[sid]:] = sscore
|
213 |
-
_, is_selected_phrase[sent_start_id[sid]:] = \
|
214 |
-
get_match_phrase(query_words, all_words[sent_start_id[sid]:])
|
215 |
-
|
216 |
-
# update selected phrase scores (-1 meaning a different color in gradio)
|
217 |
-
word_scores[is_selected_sent+is_selected_phrase==2] = -0.5
|
218 |
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
|
225 |
return output
|
226 |
|
227 |
-
def get_highlight_info(model, tokenizer, text1, text2, K=None):
|
228 |
"""
|
229 |
Get highlight information from two texts
|
230 |
"""
|
231 |
sent1 = sent_tokenize(text1) # query
|
232 |
sent2 = sent_tokenize(text2) # candidate
|
233 |
-
if K is None: # if K is not set, select based on the length of the candidate
|
234 |
-
K = int(len(sent2) / 3)
|
235 |
score_mat = compute_sentencewise_scores(model, sent1, sent2, tokenizer=tokenizer)
|
236 |
-
|
|
|
|
|
|
|
237 |
sent_ids, sent_scores = get_top_k(score_mat, K=K)
|
238 |
words2, all_words2, sent_start_id2 = get_words(sent2)
|
239 |
info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
|
240 |
|
241 |
-
# get top sentence pairs from the query and candidate (score, index_pair)
|
242 |
-
top_pair_num = 5
|
243 |
top_pairs = []
|
244 |
ii = np.unravel_index(np.argsort(np.array(sent_scores).ravel())[-top_pair_num:], sent_scores.shape)
|
245 |
for i, j in zip(ii[0][::-1], ii[1][::-1]):
|
|
|
40 |
"""
|
41 |
Pick top K sentences to show
|
42 |
"""
|
43 |
+
picked_scores, picked_sent = torch.sort(-score_mat, axis=1)
|
44 |
+
picked_sent = picked_sent[:,:K]
|
45 |
+
picked_scores = -picked_scores[:,:K]
|
|
|
|
|
|
|
46 |
return picked_sent, picked_scores
|
47 |
|
48 |
def get_words(sent):
|
|
|
54 |
sent_start_id = [] # keep track of the word index where the new sentence starts
|
55 |
counter = 0
|
56 |
for x in sent:
|
|
|
57 |
w = word_tokenize(x)
|
58 |
nw = len(w)
|
59 |
counter += nw
|
|
|
176 |
assert(len(word_out) == len(attr_out))
|
177 |
return word_out, attr_out
|
178 |
|
179 |
+
def scale_scores(arr, vmin=0.1, vmax=1):
    """Linearly rescale the positive entries of ``arr`` into ``[vmin, vmax]``.

    The smallest positive value maps to ``vmin`` and the largest to ``vmax``.
    Entries that are exactly 0 stay 0. Negative entries go through the same
    linear map, so they end up below ``vmin``.

    Edge cases:
        * No positive entries: a float copy of the input is returned
          unchanged, because there is nothing to scale against.
        * All positive entries equal: they are all set to ``vmax`` and other
          entries are left as they are, which avoids a division by zero.

    Args:
        arr: array-like of scores; it is not modified.
        vmin: value assigned to the smallest positive score.
        vmax: value assigned to the largest positive score.

    Returns:
        A new float ``numpy.ndarray`` with the same shape as ``arr``.
    """
    arr = np.asarray(arr, dtype=float)
    positive = arr[arr > 0]

    if positive.size == 0:
        return arr.copy()

    pos_min, pos_max = positive.min(), positive.max()
    if pos_max == pos_min:
        # A zero-width range would divide by zero and yield NaN/inf.
        out = np.where(arr > 0, float(vmax), arr)
    else:
        out = (arr - pos_min) / (pos_max - pos_min) * (vmax - vmin) + vmin

    # Keep unselected (zero) entries at zero; a boolean mask works for any shape.
    out[arr == 0.0] = 0.0
    return out
|
187 |
+
|
188 |
def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
|
189 |
"""
|
190 |
Mark the words that are highlighted, both by in terms of sentence and phrase
|
191 |
"""
|
192 |
num_query_sent = sent_ids.shape[0]
|
193 |
+
num_cand_sent = sent_ids.shape[1]
|
194 |
num_words = len(all_words)
|
195 |
|
196 |
output = dict()
|
|
|
199 |
|
200 |
# for each query sentence, mark the highlight information
|
201 |
for i in range(num_query_sent):
|
202 |
+
output[i] = dict()
|
203 |
+
for j in range(1, num_cand_sent+1): # for each number of selected sentences from candidate
|
204 |
+
query_words = word_tokenize(query_sents[i])
|
205 |
+
is_selected_sent = np.zeros(num_words)
|
206 |
+
is_selected_phrase = np.zeros(num_words)
|
207 |
+
word_scores = np.zeros(num_words)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
+
# for each selected sentences from the candidate, compile information
|
210 |
+
for sid, sscore in zip(sent_ids[i][:j], sent_scores[i][:j]):
|
211 |
+
#print(len(sent_start_id), sid, sid+1)
|
212 |
+
if sid+1 < len(sent_start_id):
|
213 |
+
sent_range = (sent_start_id[sid], sent_start_id[sid+1])
|
214 |
+
is_selected_sent[sent_range[0]:sent_range[1]] = 1
|
215 |
+
word_scores[sent_range[0]:sent_range[1]] = sscore
|
216 |
+
_, is_selected_phrase[sent_range[0]:sent_range[1]] = \
|
217 |
+
get_match_phrase(query_words, all_words[sent_range[0]:sent_range[1]])
|
218 |
+
else:
|
219 |
+
is_selected_sent[sent_start_id[sid]:] = 1
|
220 |
+
word_scores[sent_start_id[sid]:] = sscore
|
221 |
+
_, is_selected_phrase[sent_start_id[sid]:] = \
|
222 |
+
get_match_phrase(query_words, all_words[sent_start_id[sid]:])
|
223 |
+
|
224 |
+
# scale the word_scores: maximum value gets the darkest, minimum value gets the lightest color
|
225 |
+
if j > 1:
|
226 |
+
word_scores = scale_scores(word_scores)
|
227 |
+
|
228 |
+
# update selected phrase scores (-1 meaning a different color in gradio)
|
229 |
+
word_scores[is_selected_sent+is_selected_phrase==2] = -0.5
|
230 |
+
|
231 |
+
output[i][j] = {
|
232 |
+
'is_selected_sent': is_selected_sent,
|
233 |
+
'is_selected_phrase': is_selected_phrase,
|
234 |
+
'scores': word_scores
|
235 |
+
}
|
236 |
|
237 |
return output
|
238 |
|
239 |
+
def get_highlight_info(model, tokenizer, text1, text2, K=None, top_pair_num=5):
|
240 |
"""
|
241 |
Get highlight information from two texts
|
242 |
"""
|
243 |
sent1 = sent_tokenize(text1) # query
|
244 |
sent2 = sent_tokenize(text2) # candidate
|
|
|
|
|
245 |
score_mat = compute_sentencewise_scores(model, sent1, sent2, tokenizer=tokenizer)
|
246 |
+
|
247 |
+
if K is None: # if K is not set, get all information
|
248 |
+
K = score_mat.shape[1]
|
249 |
+
|
250 |
sent_ids, sent_scores = get_top_k(score_mat, K=K)
|
251 |
words2, all_words2, sent_start_id2 = get_words(sent2)
|
252 |
info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
|
253 |
|
254 |
+
# get top sentence pairs from the query and candidate (score, index_pair) to show upfront
|
|
|
255 |
top_pairs = []
|
256 |
ii = np.unravel_index(np.argsort(np.array(sent_scores).ravel())[-top_pair_num:], sent_scores.shape)
|
257 |
for i, j in zip(ii[0][::-1], ii[1][::-1]):
|