import base64
import itertools
import os
import re
from typing import AnyStr, Dict

import en_core_web_lg
import numpy as np
import streamlit as st
import torch.nn.parameter
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from spacy_streamlit.util import get_svg
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

from custom_renderer import render_sentence_custom
# Torch device selection: CUDA when available, otherwise the CPU.
# NOTE(review): `device` is not referenced elsewhere in this file;
# get_summarizer_model chooses its pipeline device on its own.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Bordered, horizontally scrollable container used to wrap rendered HTML
# snippets; the "{}" placeholder is filled in with str.format.
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
margin-bottom: 2.5rem">{}</div> """
@st.experimental_singleton
def get_sentence_embedding_model():
    """Return the all-MiniLM-L6-v2 SentenceTransformer, cached via ``st.experimental_singleton``.

    Used by get_and_compare_entities to compare summary and article entities
    by embedding similarity.
    """
    model_id = 'sentence-transformers/all-MiniLM-L6-v2'
    return SentenceTransformer(model_id)
@st.experimental_singleton
def get_spacy():
    """Return the spaCy ``en_core_web_lg`` pipeline, cached via ``st.experimental_singleton``."""
    return en_core_web_lg.load()
@st.experimental_singleton
def get_transformer_pipeline():
    """Return an XLM-RoBERTa (CoNLL-03 fine-tuned) "ner" pipeline, cached via ``st.experimental_singleton``.

    The pipeline is built with ``grouped_entities=True``; its results are read
    through their ``"word"`` field in get_all_entities_per_sentence.
    """
    checkpoint = "xlm-roberta-large-finetuned-conll03-english"
    ner_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    ner_network = AutoModelForTokenClassification.from_pretrained(checkpoint)
    return pipeline("ner", model=ner_network, tokenizer=ner_tokenizer, grouped_entities=True)
@st.experimental_singleton
def get_summarizer_model():
    """Return the PEGASUS (CNN/DailyMail) "summarization" pipeline, cached via ``st.experimental_singleton``.

    The pipeline gets device index 0 when CUDA is available and -1 (CPU) otherwise.
    """
    checkpoint = 'google/pegasus-cnn_dailymail'
    target_device = 0 if torch.cuda.is_available() else -1
    return pipeline("summarization", model=checkpoint, tokenizer=checkpoint, device=target_device)
# Page-level configuration; all three "menu_items" entries are set to None.
st.set_page_config(
    page_title="📜 Post-processing summarization fact checker 📜",
    page_icon="",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={menu_entry: None for menu_entry in ('Get help', 'Report a bug', 'About')},
)
def list_all_article_names(directory: str = './sample-articles') -> list:
    """List the selectable article names plus the free-input option.

    Args:
        directory: Folder containing the sample articles as ``.txt`` files.

    Returns:
        The names of the ``.txt`` files in *directory* (sorted, with only the
        trailing ``.txt`` removed), followed by ``"Provide your own input"``.
    """
    suffix = '.txt'
    # Slice off only the trailing suffix; str.replace would also remove any
    # ".txt" occurring earlier in the name.
    names = [file[:-len(suffix)] for file in sorted(os.listdir(directory)) if file.endswith(suffix)]
    names.append("Provide your own input")
    return names
def fetch_article_contents(filename: str, directory: str = './sample-articles') -> str:
    """Return the text of a sample article.

    Args:
        filename: Article name without the ``.txt`` extension; it is
            lower-cased before the lookup.
        directory: Folder holding the sample articles.

    Returns:
        The file contents, or a single space for the "Provide your own input"
        option.
    """
    if filename == "Provide your own input":
        return " "
    path = os.path.join(directory, f'{filename.lower()}.txt')
    # Explicit UTF-8: the default text encoding of open() is platform/locale dependent.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
def fetch_summary_contents(filename: str, directory: str = './sample-summaries') -> str:
    """Return the stored summary for a sample article.

    Args:
        filename: Article name without the ``.txt`` extension; it is
            lower-cased before the lookup.
        directory: Folder holding the stored summaries.

    Returns:
        The contents of ``<directory>/<filename lower-cased>.txt``.
    """
    path = os.path.join(directory, f'{filename.lower()}.txt')
    # Explicit UTF-8: the default text encoding of open() is platform/locale dependent.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
def fetch_entity_specific_contents(filename: str, directory: str = './entity-specific-text') -> str:
    """Return the entity-matching explanation text for a sample article.

    Args:
        filename: Article name without the ``.txt`` extension; it is
            lower-cased before the lookup.
        directory: Folder holding the entity-matching explanations.

    Returns:
        The contents of ``<directory>/<filename lower-cased>.txt``.
    """
    path = os.path.join(directory, f'{filename.lower()}.txt')
    # Explicit UTF-8: the default text encoding of open() is platform/locale dependent.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
def fetch_dependency_specific_contents(filename: str, directory: str = './dependency-specific-text') -> str:
    """Return the dependency-comparison explanation text for a sample article.

    Args:
        filename: Article name without the ``.txt`` extension; it is
            lower-cased before the lookup.
        directory: Folder holding the dependency explanations.

    Returns:
        The contents of ``<directory>/<filename lower-cased>.txt``.
    """
    path = os.path.join(directory, f'{filename.lower()}.txt')
    # Explicit UTF-8: the default text encoding of open() is platform/locale dependent.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
def fetch_dependency_svg(filename: str, directory: str = './dependency-images') -> list:
    """Return the stored dependency-image markup for a sample article, one entry per line.

    Args:
        filename: Article name without the ``.txt`` extension; it is
            lower-cased before the lookup.
        directory: Folder holding the dependency images.

    Returns:
        The lines of ``<directory>/<filename lower-cased>.txt`` with trailing
        whitespace (including the newline) stripped.
    """
    path = os.path.join(directory, f'{filename.lower()}.txt')
    # Explicit UTF-8: the default text encoding of open() is platform/locale dependent.
    with open(path, 'r', encoding='utf-8') as f:
        return [line.rstrip() for line in f]
def display_summary(summary_content: str):
    """Record *summary_content* as the current summary and return it wrapped for display.

    Side effect: stores the summary in ``st.session_state.summary_output``,
    which the entity and dependency sections read later.

    Returns:
        ``HTML_WRAPPER`` filled with the summary after parsing it with
        BeautifulSoup's ``html.parser``.
    """
    st.session_state.summary_output = summary_content
    return HTML_WRAPPER.format(BeautifulSoup(summary_content, features="html.parser"))
def get_all_entities_per_sentence(text):
    """Extract named entities from *text*, grouped by sentence.

    The text is split into sentences by the spaCy pipeline ``nlp``. For each
    sentence, the group holds the spaCy entity texts followed by the ``"word"``
    field of every result the ``ner_model`` pipeline returns for that sentence.
    The same entity can therefore appear twice when both models find it.

    Returns:
        A list with one list of entity strings per sentence, in sentence order.
    """
    per_sentence = []
    for sent in nlp(text).sents:
        spacy_entities = [str(ent) for ent in sent.ents]
        xlm_entities = [str(prediction["word"]) for prediction in ner_model(str(sent))]
        per_sentence.append(spacy_entities + xlm_entities)
    return per_sentence
def get_all_entities(text):
    """Return every entity string found in *text* as one flat list, in sentence order."""
    return [
        entity
        for sentence_entities in get_all_entities_per_sentence(text)
        for entity in sentence_entities
    ]
def get_and_compare_entities():
    """Split the summary's entities into those supported by the article and the rest.

    Entities are extracted from ``st.session_state.article_text`` and
    ``st.session_state.summary_output``. A summary entity counts as matched when:

    * it is a case-insensitive substring of some article entity, or
    * the inner product of its sentence embedding with some article entity's
      embedding exceeds 0.9.

    Returns:
        A tuple ``(matched_entities, unmatched_entities)`` of summary entity
        strings, in summary order, with duplicates preserved.
    """
    entities_article = get_all_entities(st.session_state.article_text)
    entities_summary = get_all_entities(st.session_state.summary_output)

    # Hoisted out of the loop: lower-case each article entity once.
    article_lower = [entity.lower() for entity in entities_article]
    # Encode all article entities once (batched). Previously every article entity
    # was re-encoded for every summary entity that missed the substring test.
    article_embeddings = None
    if entities_article:
        article_embeddings = sentence_embedding_model.encode(entities_article, show_progress_bar=False)

    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        entity_lower = entity.lower()
        if any(entity_lower in art_entity for art_entity in article_lower):
            matched_entities.append(entity)
            continue
        if article_embeddings is not None:
            entity_embedding = sentence_embedding_model.encode(entity, show_progress_bar=False)
            # One inner product per article entity; same 0.9 threshold as before.
            if np.any(np.inner(article_embeddings, entity_embedding) > 0.9):
                matched_entities.append(entity)
                continue
        unmatched_entities.append(entity)
    return matched_entities, unmatched_entities
def highlight_entities():
    """Return the current summary as HTML with entities highlighted.

    Entities that get_and_compare_entities reports as matched are wrapped in a
    green ``<mark>`` and unmatched ones in a red ``<mark>``. The result is
    parsed with BeautifulSoup and placed in ``HTML_WRAPPER``.
    """
    summary_content = st.session_state.summary_output
    markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
    markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
    markdown_end = "</mark>"
    matched_entities, unmatched_entities = get_and_compare_entities()

    # One opening tag per distinct, non-empty entity. Unmatched entities are
    # assigned last so they take precedence, as in the original replace order.
    # An empty entity is skipped: replacing "" would insert markup between
    # every character.
    opening_tag = {}
    for entity in matched_entities:
        if entity:
            opening_tag[entity] = markdown_start_green
    for entity in unmatched_entities:
        if entity:
            opening_tag[entity] = markdown_start_red

    if opening_tag:
        # A single left-to-right pass, trying longer entities first. Chained
        # str.replace calls re-wrapped entities that appear more than once in
        # the lists. They also nested shorter entities inside already
        # highlighted longer ones, and could match text inside inserted tags.
        pattern = re.compile("|".join(
            re.escape(entity) for entity in sorted(opening_tag, key=len, reverse=True)))
        summary_content = pattern.sub(
            lambda match: opening_tag[match.group(0)] + match.group(0) + markdown_end,
            summary_content)

    soup = BeautifulSoup(summary_content, features="html.parser")
    return HTML_WRAPPER.format(soup)
def render_dependency_parsing(text: Dict):
    """Draw one dependency record on the page as an SVG.

    Args:
        text: A dependency dict, as built by check_dependency. It is passed,
            together with ``nlp``, to ``render_sentence_custom``.
    """
    # Blank lines are collapsed, presumably so the markdown renderer behind
    # st.write does not split the markup at them -- TODO confirm.
    markup = render_sentence_custom(text, nlp).replace("\n\n", "\n")
    st.write(get_svg(markup), unsafe_allow_html=True)
def check_dependency(article: bool):
    """Collect entity-related 'amod' and 'pobj'-of-"in" dependencies from the article or summary.

    A token qualifies when its dependency label is ``amod``, or ``pobj`` with
    the head word "in" (case-insensitive). It must also satisfy one of:

    * the token's text is one of its sentence's extracted entities, or
    * its head's text is one of its sentence's extracted entities.

    Args:
        article: If True, analyse ``st.session_state.article_text``; otherwise
            analyse ``st.session_state.summary_output``.

    Returns:
        A list of dicts, one per qualifying token, with these keys:

        * ``dep``: the dependency label.
        * ``cur_word_index`` / ``target_word_index``: token and head indices
          relative to the sentence start.
        * ``identifier``: word + label + head word, used to compare summary
          and article.
        * ``sentence``: the sentence text.
    """
    text = st.session_state.article_text if article else st.session_state.summary_output
    all_entities = get_all_entities_per_sentence(text)
    doc = nlp(text)
    tok_l = doc.to_json()['tokens']
    test_list_dict_output = []
    for i, sentence in enumerate(doc.sents):
        sentence_entities = all_entities[i]
        # Span.end is exclusive, so the sentence's tokens are [start, end).
        # The former `id > end` check also processed the next sentence's
        # first token against this sentence's entities.
        for t in tok_l[sentence.start:sentence.end]:
            if t['dep'] not in ('amod', 'pobj'):
                continue
            head = tok_l[t['head']]
            object_here = text[t['start']:t['end']]
            object_target = text[head['start']:head['end']]
            # Prepositional objects only count when governed by "in".
            if t['dep'] == 'pobj' and object_target.lower() != 'in':
                continue
            if object_here in sentence_entities or object_target in sentence_entities:
                test_list_dict_output.append({
                    "dep": t['dep'],
                    "cur_word_index": t['id'] - sentence.start,
                    "target_word_index": t['head'] - sentence.start,
                    "identifier": object_here + t['dep'] + object_target,
                    "sentence": str(sentence),
                })
    return test_list_dict_output
def render_svg(svg_file):
    """Return an ``<img>`` tag embedding the SVG file at *svg_file* as base64 data.

    The file is read as UTF-8 text and base64-encoded into a
    ``data:image/svg+xml`` URI, so the result can be passed to
    ``st.write(..., unsafe_allow_html=True)``.
    """
    # Explicit UTF-8: the default text encoding of open() is platform/locale dependent.
    with open(svg_file, "r", encoding="utf-8") as f:
        svg = f.read()
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    return r'<img src="data:image/svg+xml;base64,%s"/>' % b64
def generate_abstractive_summary(text, type, min_len=120, max_len=512, **kwargs):
    """Summarize *text* with the module-level ``summarization_model`` pipeline.

    Args:
        text: Article text; it is stripped and its newlines are replaced by
            spaces before summarizing.
        type: Decoding strategy: ``"top_p"``, ``"greedy"``, ``"top_k"`` or
            ``"beam"``. The name shadows the builtin but is kept because
            callers pass it by keyword.
        min_len: Passed to the pipeline as ``min_length``.
        max_len: Passed to the pipeline as ``max_length``.
        **kwargs: Extra generation arguments forwarded to the pipeline
            (e.g. ``num_beams``, ``do_sample``, ``no_repeat_ngram_size``).
            They override the strategy defaults below.

    Returns:
        The summary text with ``<n>`` markers replaced by spaces.

    Raises:
        ValueError: If *type* is not a supported strategy.
    """
    strategy_defaults = {
        "top_p": {"top_k": 50, "top_p": 0.95},
        "greedy": {},
        "top_k": {"top_k": 50},
        "beam": {},
    }
    if type not in strategy_defaults:
        # Previously an unknown strategy fell through and failed later with an
        # opaque TypeError when indexing the (still string) input.
        raise ValueError(f"Unknown decoding type {type!r}; expected one of {sorted(strategy_defaults)}")
    generation_kwargs = {**strategy_defaults[type], **kwargs}
    if type in ("top_p", "top_k"):
        # top_k / top_p only act in sampling mode; enable it unless the caller
        # explicitly decided otherwise.
        generation_kwargs.setdefault("do_sample", True)

    text = text.strip().replace("\n", " ")
    output = summarization_model(text, min_length=min_len, max_length=max_len,
                                 clean_up_tokenization_spaces=True, truncation=True, **generation_kwargs)
    return output[0]['summary_text'].replace("<n>", " ")
st.title('📜 Summarization fact checker 📜')
st.header("Introduction")
st.markdown("""Recent work using transformers on large text corpora has shown great success when fine-tuned on
several different downstream NLP tasks. One such task is that of text summarization. The goal of text summarization
is to generate concise and accurate summaries from input document(s). There are 2 types of summarization: extractive
and abstractive. **Extractive summarization** merely copies informative fragments from the input,
whereas **abstractive summarization** may generate novel words. A good abstractive summary should cover principal
information in the input and has to be linguistically fluent. This interactive blogpost will focus on this more difficult task of
abstractive summary generation.""")
# A markdown link is "[text](url)" with no space between "]" and "("; with the
# space it rendered as literal text instead of a link.
st.markdown("""To generate summaries we will use the [PEGASUS](https://huggingface.co/google/pegasus-cnn_dailymail)
model, producing abstractive summaries from large articles. These summaries often contain sentences with different
kinds of errors. Rather than improving the core model, we will look into possible post-processing steps to detect errors
from the generated summaries. Throughout this blog, we will also explain the results for some methods on specific
examples. These text blocks will be indicated and they change according to the currently selected article.""")

# Load the cached models into module-level names; the helper functions defined
# above read these names at call time.
sentence_embedding_model = get_sentence_embedding_model()
ner_model = get_transformer_pipeline()
nlp = get_spacy()
summarization_model = get_summarizer_model()

st.header("🪶 Generating summaries")
st.markdown("Let’s start by selecting an article text for which we want to generate a summary, or you can provide "
            "text yourself. Note that it’s suggested to provide a sufficiently large article, as otherwise the "
            "summary generated from it might not be optimal, leading to suboptimal performance of the post-processing "
            "steps. However, too long articles will be truncated and might miss information in the summary.")

selected_article = st.selectbox('Select an article or provide your own:',
                                list_all_article_names())
st.session_state.article_text = fetch_article_contents(selected_article)
article_text = st.text_area(
    label='Full article text',
    value=st.session_state.article_text,
    height=150
)

summarize_button = st.button(label='Process article content',
                             help="Start interactive blogpost")
if summarize_button:
    st.session_state.article_text = article_text
    st.markdown(
        "Below you can find the generated summary for the article. We will discuss two approaches that we found are "
        "able to detect some common errors. Based on errors, one could then score different summaries, indicating how "
        "factual a summary is for a given article. The idea is that in production, you could generate a set of "
        "summaries for the same article, with different parameters (or even different models). By using "
        "post-processing error detection, we can then select the best possible summary.")
    # Whitespace-only input (e.g. the single space used for "Provide your own
    # input") is treated as empty instead of being sent to the summarizer.
    if st.session_state.article_text.strip():
        with st.spinner('Generating summary, this might take a while...'):
            if selected_article != "Provide your own input" and article_text == fetch_article_contents(
                    selected_article):
                # Unedited sample article: use the stored summary and explanations.
                st.session_state.unchanged_text = True
                summary_content = fetch_summary_contents(selected_article)
            else:
                summary_content = generate_abstractive_summary(article_text, type="beam", do_sample=True, num_beams=15,
                                                               no_repeat_ngram_size=4)
                st.session_state.unchanged_text = False
            summary_displayed = display_summary(summary_content)
            st.write("**Generated summary:**", summary_displayed, unsafe_allow_html=True)
    else:
        st.error('**Error**: No article text to summarize. Please provide an article.')
        # Everything below reads st.session_state.summary_output and
        # st.session_state.unchanged_text, which are only set after a summary
        # is generated. Stop here rather than crash or reuse a previous run's
        # summary.
        st.stop()

    st.header("Entity matching")
    st.markdown("The first method we will discuss is called **Named Entity Recognition** (NER). NER is the task of "
                "identifying and categorising key information (entities) in text. An entity can be a singular word or a "
                "series of words that consistently refers to the same thing. Common entity classes are person names, "
                "organisations, locations and so on. By applying NER to both the article and its summary, we can spot "
                "possible **hallucinations**. Hallucinations are words generated by the model that are not supported by "
                "the source input. In theory all entities in the summary (such as dates, locations and so on), "
                "should also be present in the article. Thus we can extract all entities from the summary and compare "
                "them to the entities of the original article, spotting potential hallucinations. The more unmatched "
                "entities we find, the lower the factualness score of the summary. ")
    with st.spinner("Calculating and matching entities..."):
        entity_match_html = highlight_entities()
        st.write(entity_match_html, unsafe_allow_html=True)
    red_text = """<font color="black"><span style="background-color: rgb(238, 135, 135); opacity:
1;">red</span></font> """
    green_text = """<font color="black">
<span style="background-color: rgb(121, 236, 121); opacity: 1;">green</span>
</font>"""
    # NOTE(review): these two names are not used below; highlight_entities
    # defines its own copies.
    markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
    markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
    st.markdown(
        "We call this technique “entity matching” and here you can see what this looks like when we apply this "
        "method on the summary. Entities in the summary are marked " + green_text + " when the entity also "
        "exists in the article, "
        "while unmatched entities "
        "are marked " + red_text +
        ". Several of the example articles and their summaries indicate different errors we find by using this "
        "technique. Based on the current article, we provide a short explanation of the results below **(only for "
        "example articles)**. ", unsafe_allow_html=True)
    if st.session_state.unchanged_text:
        entity_specific_text = fetch_entity_specific_contents(selected_article)
        soup = BeautifulSoup(entity_specific_text, features="html.parser")
        st.write("💡👇 **Specific example explanation** 👇💡", HTML_WRAPPER.format(soup), unsafe_allow_html=True)

    st.header("Dependency comparison")
    st.markdown(
        "The second method we use for post-processing is called **Dependency parsing**: the process in which the "
        "grammatical structure in a sentence is analysed, to find out related words as well as the type of the "
        "relationship between them. For the sentence “Jan’s wife is called Sarah” you would get the following "
        "dependency graph:")
    st.write(render_svg('ExampleParsing.svg'), unsafe_allow_html=True)
    st.markdown("Here, “Jan” is the “poss” (possession modifier) of “wife”. If suddenly the summary would read “Jan’s "
                "husband…”, there would be a dependency in the summary that is non-existent in the article itself (namely "
                "“Jan” is the “poss” of “husband”). However, often new dependencies are introduced in the summary that "
                "are still correct. “The borders of Ukraine” have a different dependency between “borders” and "
                "“Ukraine” "
                "than “Ukraine’s borders”, while both descriptions have the same meaning. So just matching all "
                "dependencies between article and summary (as we did with entity matching) would not be a robust method.")
    st.markdown("However, we have found that there are specific dependencies that, when unmatched, are often an "
                "indication of a wrongly constructed sentence. We found 2(/3 TODO) common dependencies which, "
                "when present in the summary but not in the article, are highly indicative of factualness errors. "
                "Furthermore, we only check dependencies between an existing **entity** and its direct connections. "
                "Below we highlight all unmatched dependencies that satisfy the discussed constraints. We also "
                "discuss the specific results for the currently selected example article.")
    with st.spinner("Doing dependency parsing..."):
        if st.session_state.unchanged_text:
            for cur_svg_image in fetch_dependency_svg(selected_article):
                st.write(cur_svg_image, unsafe_allow_html=True)
            dep_specific_text = fetch_dependency_specific_contents(selected_article)
            soup = BeautifulSoup(dep_specific_text, features="html.parser")
            st.write("💡👇 **Specific example explanation** 👇💡", HTML_WRAPPER.format(soup), unsafe_allow_html=True)
        else:
            summary_deps = check_dependency(False)
            article_deps = check_dependency(True)
            # A summary dependency is unmatched when its identifier does not
            # occur (as a substring) in any article dependency identifier.
            total_unmatched_deps = [
                summ_dep for summ_dep in summary_deps
                if not any(summ_dep['identifier'] in art_dep['identifier'] for art_dep in article_deps)
            ]
            for current_drawing_list in total_unmatched_deps:
                render_dependency_parsing(current_drawing_list)

    st.header("Wrapping up")
    st.markdown("We have presented 2 methods that try to detect errors in summaries via post-processing steps. Entity "
                "matching can be used to solve hallucinations, while dependency comparison can be used to filter out "
                "some bad sentences (and thus worse summaries). These methods highlight the possibilities of "
                "post-processing AI-made summaries, but are only a first introduction. As the methods were "
                "empirically tested they are definitely not sufficiently robust for general use-cases.")
    st.markdown("####")
    st.markdown("(TODO) Below we generated 5 different kind of summaries from the article in which their ranks are estimated, "
                "and hopefully the best summary (read: the one that a human would prefer or indicate as the best one) "
                "will be at the top. TODO: implement this (at the end I think) and also put something in the text with "
                "the actual parameters or something? ")