|
from typing import AnyStr, Dict |
|
|
|
import itertools |
|
import streamlit as st |
|
import en_core_web_lg |
|
|
|
import torch.nn.parameter |
|
from bs4 import BeautifulSoup |
|
import numpy as np |
|
import base64 |
|
|
|
from spacy_streamlit.util import get_svg |
|
from streamlit.proto.SessionState_pb2 import SessionState |
|
|
|
from custom_renderer import render_sentence_custom |
|
from sentence_transformers import SentenceTransformer |
|
|
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
from transformers import pipeline |
|
import os |
|
|
|
# Torch device for inference: a CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Bordered, horizontally scrollable container; `{}` is replaced by the HTML to display.
HTML_WRAPPER = (
    '<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;\n'
    'margin-bottom: 2.5rem">{}</div> '
)
|
|
|
|
|
@st.experimental_singleton
def get_sentence_embedding_model():
    """Return the sentence-transformers model used to embed entity strings.

    The ``st.experimental_singleton`` decorator lets Streamlit cache the returned
    model instead of rebuilding it on every script rerun.
    """
    model_id = 'sentence-transformers/all-MiniLM-L6-v2'
    embedding_model = SentenceTransformer(model_id)
    return embedding_model
|
|
|
|
|
@st.experimental_singleton
def get_spacy():
    """Return the large English spaCy pipeline (``en_core_web_lg``).

    The app uses it for sentence splitting, spaCy NER and dependency parsing.
    The singleton decorator lets Streamlit cache the loaded pipeline.
    """
    return en_core_web_lg.load()
|
|
|
|
|
@st.experimental_singleton
def get_transformer_pipeline():
    """Build the XLM-RoBERTa (CoNLL-03 fine-tuned) NER pipeline.

    Sub-word predictions are merged into whole entities. Each result is a dict
    whose ``"word"`` key holds the merged entity text; this is the key
    ``get_all_entities_per_sentence`` reads.
    """
    model_name = "xlm-roberta-large-finetuned-conll03-english"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    # `aggregation_strategy="simple"` is the documented replacement for the deprecated
    # `grouped_entities=True` flag and selects the same grouping behaviour.
    return pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
|
|
|
|
|
@st.experimental_singleton
def get_summarizer_model():
    """Return the PEGASUS (CNN/DailyMail) summarization pipeline.

    It runs on the first CUDA device when available, otherwise on the CPU.
    """
    checkpoint = 'google/pegasus-cnn_dailymail'
    # Pipelines take an integer device: 0 is the first GPU, -1 is the CPU.
    device_index = 0 if torch.cuda.is_available() else -1
    return pipeline("summarization", model=checkpoint, tokenizer=checkpoint, device=device_index)
|
|
|
|
|
|
|
# Page-level configuration. NOTE: Streamlit documents that set_page_config must run
# before other st.* commands, which is why it sits near the top of the script.
st.set_page_config(
    page_title="📜 Hallucination detection in summaries 📜",
    page_icon="",
    layout="centered",
    initial_sidebar_state="auto",
    # All three menu entries are explicitly set to None.
    menu_items=dict.fromkeys(('Get help', 'Report a bug', 'About')),
)
|
|
|
|
|
def list_all_article_names(directory: str = './sample-articles/') -> list:
    """Return the options for the article selectbox.

    Every ``*.txt`` file in *directory* contributes its file name without the
    trailing ``.txt``, in sorted file-name order. The "Provide your own input"
    option is always appended last.

    Args:
        directory: folder holding the example articles. Defaults to the folder
            the rest of the app reads articles from.

    Returns:
        The article names followed by "Provide your own input".
    """
    suffix = '.txt'
    # Strip only the trailing extension. `str.replace` would also remove a ".txt"
    # elsewhere in the name, producing a name that no longer maps back to its file.
    names = [file[:-len(suffix)] for file in sorted(os.listdir(directory)) if file.endswith(suffix)]
    names.append("Provide your own input")
    return names
|
|
|
|
|
def fetch_article_contents(filename: str) -> str:
    """Return the text of example article *filename* from ``./sample-articles/``.

    "Provide your own input" has no backing file. A single space is returned
    for it, and the caller uses that as the text area's initial value.
    """
    if filename == "Provide your own input":
        return " "
    # Read as UTF-8 explicitly: the default encoding is platform-dependent and the
    # article text may contain non-ASCII characters.
    with open(f'./sample-articles/{filename}.txt', 'r', encoding='utf-8') as f:
        return f.read()
|
|
|
|
|
def fetch_summary_contents(filename: str) -> str:
    """Return the pre-generated summary for example article *filename*.

    Reads ``./sample-summaries/<filename>.txt`` as UTF-8. The encoding is given
    explicitly because the platform default may differ.
    """
    with open(f'./sample-summaries/{filename}.txt', 'r', encoding='utf-8') as f:
        return f.read()
|
|
|
|
|
def fetch_entity_specific_contents(filename: str) -> str:
    """Return the entity-matching explanation written for example article *filename*.

    Reads ``./entity-specific-text/<filename>.txt`` as UTF-8. The encoding is
    given explicitly because the platform default may differ. The caller parses
    the result with BeautifulSoup.
    """
    with open(f'./entity-specific-text/{filename}.txt', 'r', encoding='utf-8') as f:
        return f.read()
|
|
|
|
|
def fetch_dependency_specific_contents(filename: str) -> str:
    """Return the dependency-comparison explanation written for example article *filename*.

    Reads ``./dependency-specific-text/<filename>.txt`` as UTF-8. The encoding
    is given explicitly because the platform default may differ.
    """
    with open(f'./dependency-specific-text/{filename}.txt', 'r', encoding='utf-8') as f:
        return f.read()
|
|
|
|
|
def fetch_ranked_summaries(filename: str, ranknumber: int) -> str:
    """Return candidate summary number *ranknumber* for example article *filename*.

    Reads ``./ranked-summaries/<filename>/Rank<ranknumber>.txt`` as UTF-8. The
    encoding is given explicitly because the platform default may differ.
    """
    with open(f'./ranked-summaries/{filename}/Rank{ranknumber}.txt', 'r', encoding='utf-8') as f:
        return f.read()
|
|
|
|
|
def fetch_dependency_svg(filename: str) -> list:
    """Return the pre-rendered dependency snippets for example article *filename*.

    Reads ``./dependency-images/<filename>.txt`` as UTF-8 and returns one string
    per line, with trailing whitespace (including the newline) removed. The
    caller writes each line to the page as HTML. The old ``AnyStr`` annotation
    was wrong: a list is returned.
    """
    with open(f'./dependency-images/{filename}.txt', 'r', encoding='utf-8') as f:
        return [line.rstrip() for line in f]
|
|
|
|
|
def display_summary(summary_content: str):
    """Record *summary_content* as the current summary and return it wrapped for display.

    Side effect: stores the raw summary in ``st.session_state.summary_output``.
    The entity-matching and dependency code read it from there.

    Returns:
        The summary parsed with BeautifulSoup (``html.parser``) and formatted
        into ``HTML_WRAPPER``.
    """
    st.session_state.summary_output = summary_content
    return HTML_WRAPPER.format(BeautifulSoup(summary_content, features="html.parser"))
|
|
|
|
|
def get_all_entities_per_sentence(text):
    """Return the entity strings found in each sentence of *text*.

    *text* is split into sentences with the spaCy pipeline ``nlp``. For every
    sentence the result holds the spaCy entities first, then the ``"word"`` of
    every entity the transformer NER pipeline ``ner_model`` finds in that
    sentence's text. Duplicates are kept.

    Returns:
        One list of strings per sentence, in sentence order.
    """
    per_sentence = []
    for sent in nlp(text).sents:
        spacy_entities = [str(ent) for ent in sent.ents]
        xlm_entities = [str(prediction["word"]) for prediction in ner_model(str(sent))]
        per_sentence.append(spacy_entities + xlm_entities)
    return per_sentence
|
|
|
|
|
def get_all_entities(text):
    """Return every entity string in *text* as one flat list, in sentence order."""
    flat = []
    for sentence_entities in get_all_entities_per_sentence(text):
        flat.extend(sentence_entities)
    return flat
|
|
|
|
|
def _drop_nested_entities(entities):
    """Deduplicate *entities* and drop every entity contained in another one (case-insensitive).

    Entities that differ only in case collapse to the first spelling seen.
    Previously each counted as contained in the other, so both were dropped.
    First-occurrence order is preserved.
    """
    unique = list(dict.fromkeys(entities))
    lowered = [entity.lower() for entity in unique]
    kept = []
    seen_lower = set()
    for entity, low in zip(unique, lowered):
        if low in seen_lower:
            continue  # case variant of an entity already kept
        if any(low != other and low in other for other in lowered):
            continue  # strictly contained in a longer entity
        seen_lower.add(low)
        kept.append(entity)
    return kept


def get_and_compare_entities(first_time: bool):
    """Split the summary's entities into those supported by the article and those not.

    Args:
        first_time: when True, extract the entities of ``st.session_state.article_text``
            and cache them in ``st.session_state.entities_article``. When False, reuse
            that cache; an earlier call with True must have filled it.

    Returns:
        ``(matched_entities, unmatched_entities)``, built from the entities of
        ``st.session_state.summary_output``. A summary entity is matched when
        either holds:

        - it is a case-insensitive substring of some article entity;
        - its embedding has an inner product above 0.9 with some article entity's
          embedding.

        Each list is deduplicated, and entities contained in another entity of
        the same list are dropped.
    """
    if first_time:
        entities_article = get_all_entities(st.session_state.article_text)
        st.session_state.entities_article = entities_article
    else:
        entities_article = st.session_state.entities_article

    entities_summary = get_all_entities(st.session_state.summary_output)

    article_lower = [entity.lower() for entity in entities_article]
    unique_article = list(dict.fromkeys(entities_article))
    # Article embeddings are computed lazily and at most once per call. Previously every
    # article entity was re-encoded for every summary entity (O(n*m) model calls).
    article_embeddings = None

    matched_entities = []
    unmatched_entities = []
    # Classification depends only on the entity string, so classify each distinct string once.
    # dict.fromkeys keeps first-occurrence order, matching the old post-hoc deduplication.
    for entity in dict.fromkeys(entities_summary):
        entity_lower = entity.lower()
        if any(entity_lower in art_entity for art_entity in article_lower):
            matched_entities.append(entity)
            continue
        if unique_article:
            if article_embeddings is None:
                # A list input yields a 2-D array with one row per article entity.
                article_embeddings = sentence_embedding_model.encode(unique_article, show_progress_bar=False)
            entity_embedding = sentence_embedding_model.encode(entity, show_progress_bar=False)
            if np.any(np.inner(article_embeddings, entity_embedding) > 0.9):
                matched_entities.append(entity)
                continue
        unmatched_entities.append(entity)

    return _drop_nested_entities(matched_entities), _drop_nested_entities(unmatched_entities)
|
|
|
|
|
def _wrap_entities(text, opening_tag_by_entity, closing_tag):
    """Wrap each occurrence of every entity in *text* with its opening tag and *closing_tag*, in one pass.

    Longer entities are tried first at each position. Markup that has already
    been inserted is never rescanned, so an entity such as "2" or "121" cannot
    match inside an earlier tag's ``rgb(...)`` style. Empty entity strings are
    ignored.
    """
    import re

    entities = sorted((entity for entity in opening_tag_by_entity if entity), key=len, reverse=True)
    if not entities:
        return text
    pattern = re.compile("|".join(re.escape(entity) for entity in entities))
    return pattern.sub(lambda m: opening_tag_by_entity[m.group(0)] + m.group(0) + closing_tag, text)


def highlight_entities():
    """Return the current summary as wrapped HTML with its entities highlighted.

    Calls ``get_and_compare_entities(True)``, which also refreshes
    ``st.session_state.entities_article``. Entities supported by the article get
    a green ``<mark>`` and unsupported ones a red ``<mark>``.
    """
    summary_content = st.session_state.summary_output
    markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
    markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
    markdown_end = "</mark>"

    matched_entities, unmatched_entities = get_and_compare_entities(True)

    opening_tag_by_entity = {entity: markdown_start_green for entity in matched_entities}
    opening_tag_by_entity.update({entity: markdown_start_red for entity in unmatched_entities})
    # A single pass replaces the sequential str.replace calls, which could corrupt markup
    # inserted by earlier replacements and nest overlapping entities.
    highlighted = _wrap_entities(summary_content, opening_tag_by_entity, markdown_end)

    soup = BeautifulSoup(highlighted, features="html.parser")
    return HTML_WRAPPER.format(soup)
|
|
|
|
|
def render_dependency_parsing(text: Dict):
    """Draw one dependency arc on the page.

    Args:
        text: one of the dicts produced by ``check_dependency`` (keys "dep",
            "cur_word_index", "target_word_index", "identifier", "sentence"). It is
            passed to ``render_sentence_custom`` together with the spaCy pipeline.
    """
    # Collapse blank lines in the generated markup before wrapping it with get_svg.
    svg_markup = render_sentence_custom(text, nlp).replace("\n\n", "\n")
    st.write(get_svg(svg_markup), unsafe_allow_html=True)
|
|
|
|
|
def check_dependency(article: bool):
    """Collect entity-related ``amod`` / ``pobj`` dependencies of the article or the summary.

    A token's dependency is kept when both conditions hold:

    - it is an ``amod``, or a ``pobj`` whose head word is "in" (case-insensitive);
    - the token's text or its head's text is one of the entities found in the
      same sentence.

    Args:
        article: True to analyse ``st.session_state.article_text``, False for
            ``st.session_state.summary_output``.

    Returns:
        A list of dicts with keys:

        - "dep": the dependency label;
        - "cur_word_index" / "target_word_index": token and head positions,
          relative to the sentence start;
        - "identifier": token text + label + head text, used to compare article
          and summary;
        - "sentence": the sentence text.
    """
    text = st.session_state.article_text if article else st.session_state.summary_output
    all_entities = get_all_entities_per_sentence(text)
    doc = nlp(text)
    # In Doc.to_json() the token list is in document order, so a token's "id" (and a
    # "head" reference) is also its index into this list.
    tokens = doc.to_json()['tokens']

    dependencies = []
    for i, sentence in enumerate(doc.sents):
        sentence_entities = all_entities[i]
        # `sentence.end` is exclusive. Slicing [start, end) fixes the old `id > end_id` test,
        # which also let the first token of the next sentence through (and avoids scanning
        # every token of the document for every sentence).
        for t in tokens[sentence.start:sentence.end]:
            if t['dep'] not in ('amod', 'pobj'):
                continue
            head = tokens[t['head']]
            object_here = text[t['start']:t['end']]
            object_target = text[head['start']:head['end']]
            if t['dep'] == 'pobj' and object_target.lower() != 'in':
                continue
            if object_here in sentence_entities or object_target in sentence_entities:
                dependencies.append({"dep": t['dep'],
                                     "cur_word_index": t['id'] - sentence.start,
                                     "target_word_index": t['head'] - sentence.start,
                                     "identifier": object_here + t['dep'] + object_target,
                                     "sentence": str(sentence)})
    return dependencies
|
|
|
|
|
def render_svg(svg_file):
    """Return an ``<img>`` tag that embeds the SVG file *svg_file* as a base64 data URI.

    The file is read as UTF-8 text. The encoding is explicit because the platform
    default may differ and the SVG may contain non-ASCII text. The text is then
    re-encoded as UTF-8 for base64.
    """
    with open(svg_file, "r", encoding="utf-8") as f:
        svg = f.read()
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    return f'<img src="data:image/svg+xml;base64,{b64}"/>'
|
|
|
|
|
def generate_abstractive_summary(text, type, min_len=120, max_len=512, **kwargs):
    """Summarize *text* with ``summarization_model`` using the requested decoding mode.

    Args:
        text: input text. Surrounding whitespace is stripped and newlines become spaces.
        type: decoding mode: "top_p", "top_k", "greedy" or "beam". The name shadows
            the builtin but is kept because callers pass it by keyword.
        min_len: forwarded as ``min_length``.
        max_len: forwarded as ``max_length``.
        **kwargs: extra generation arguments (e.g. ``num_beams``, ``no_repeat_ngram_size``,
            ``do_sample``). They override the mode defaults below.

    Returns:
        The generated summary text, with every "<n>" token replaced by a space.

    Raises:
        ValueError: if *type* is not one of the supported modes. Previously an unknown
            mode failed later with an unrelated TypeError.
    """
    mode_defaults = {
        # top_k / top_p only take effect when sampling is enabled.
        "top_p": {"do_sample": True, "top_k": 50, "top_p": 0.95},
        "top_k": {"do_sample": True, "top_k": 50},
        # Explicitly greedy; otherwise the model's own generation defaults would apply.
        "greedy": {"do_sample": False, "num_beams": 1},
        # Beam settings (num_beams, ...) are expected from the caller via **kwargs.
        "beam": {},
    }
    if type not in mode_defaults:
        raise ValueError(f"Unknown decoding type {type!r}; expected one of {sorted(mode_defaults)}")

    generation_kwargs = {**mode_defaults[type], **kwargs}
    outputs = summarization_model(text.strip().replace("\n", " "),
                                  min_length=min_len,
                                  max_length=max_len,
                                  clean_up_tokenization_spaces=True,
                                  truncation=True,
                                  **generation_kwargs)
    return outputs[0]['summary_text'].replace("<n>", " ")
|
|
|
|
|
|
|
# Build the shared models. Each factory is decorated with st.experimental_singleton,
# so Streamlit can hand back its cached object on later reruns.
(sentence_embedding_model,
 ner_model,
 nlp,
 summarization_model) = (get_sentence_embedding_model(),
                         get_transformer_pipeline(),
                         get_spacy(),
                         get_summarizer_model())
|
|
|
|
|
# ---------------------------------------------------------------------------
# Page body. Streamlit renders elements in the order these calls execute, so
# the statement order below is the page layout.
# ---------------------------------------------------------------------------
st.title('📜 Hallucination detection 📜')
st.subheader("🔎 Detecting errors in generated abstractive summaries")

# --- Introduction section --------------------------------------------------
st.header("🧑🏫 Introduction")

st.markdown("""
Recent work using 🤖 **transformers** 🤖 on large text corpora has shown great success when fine-tuned on
several different downstream NLP tasks. One such task is that of text summarization. The goal of text summarization
is to generate concise and accurate summaries from input document(s). There are 2 types of summarization:

- **Extractive summarization** merely copies informative fragments from the input
- **Abstractive summarization** may generate novel words. A good abstractive summary should cover principal
information in the input and has to be linguistically fluent. This interactive blogpost will focus on this more difficult task of
abstractive summary generation. Furthermore we will focus on factual errors in summaries, and less sentence fluency.""")

# "###" renders an empty heading, used here as vertical spacing.
st.markdown("###")
st.markdown("🤔 **Why is this important?** 🤔 Let's say we want to summarize news articles for a popular "
            "newspaper. If an article tells the story of Elon Musk buying **Twitter**, we don't want our summarization "
            "model to say that he bought **Facebook** instead. Summarization could also be done for financial reports "
            "for example. In such environments, these errors can be very critical, so we want to find a way to "
            "detect them.")
st.markdown("###")
st.markdown("""To generate summaries we will use the 🐎 [PEGASUS](https://huggingface.co/google/pegasus-cnn_dailymail) 🐎
model, producing abstractive summaries from large articles. These summaries often contain sentences with different
kinds of errors. Rather than improving the core model, we will look into possible post-processing steps to detect errors
from the generated summaries. Throughout this blog, we will also explain the results for some methods on specific
examples. These text blocks will be indicated and they change according to the currently selected article.""")

# --- Article selection / input section -------------------------------------
st.header("🪶 Generating summaries")
st.markdown("Let’s start by selecting an article text for which we want to generate a summary, or you can provide "
            "text yourself. Note that it’s suggested to provide a sufficiently large article, as otherwise the "
            "summary generated from it might not be optimal, leading to suboptimal performance of the post-processing "
            "steps. However, too long articles will be truncated and might miss information in the summary.")

st.markdown("####")
# Options are the example article names plus "Provide your own input"; index=2 preselects the third option.
selected_article = st.selectbox('Select an article or provide your own:',
                                list_all_article_names(), index=2)
# Reset on every rerun to the selected article's text (a single space for "Provide your own input").
# It is used as the text area's initial value below.
st.session_state.article_text = fetch_article_contents(selected_article)
# `article_text` holds whatever is currently in the text area, including user edits.
article_text = st.text_area(
    label='Full article text',
    value=st.session_state.article_text,
    height=250
)

# The analysis sections further down only run on the rerun in which this button was clicked.
summarize_button = st.button(label='🤯 Process article content',
                             help="Start interactive blogpost")
|
|
|
if summarize_button:
    # Use the (possibly edited) text-area contents as the article from here on.
    st.session_state.article_text = article_text
    st.markdown("####")
    st.markdown(
        "*Below you can find the generated summary for the article. We will discuss two approaches that we found are "
        "able to detect some common errors. Based on these errors, one could then score different summaries, indicating how "
        "factual a summary is for a given article. The idea is that in production, you could generate a set of "
        "summaries for the same article, with different parameters (or even different models). By using "
        "post-processing error detection, we can then select the best possible summary.*")
    st.markdown("####")

    # Guard: "Provide your own input" pre-fills the text area with a single space, which is
    # truthy, so check the stripped text. Every section below reads the summary from session
    # state. Stop this run here rather than continue with a missing (first run) or stale
    # (earlier run) summary; the old code also showed an unrelated error message.
    if not st.session_state.article_text.strip():
        st.error('**Error**: No article text to summarize. Please provide an article.')
        st.stop()

    with st.spinner('Generating summary, this might take a while...'):
        if selected_article != "Provide your own input" and article_text == fetch_article_contents(
                selected_article):
            # Unedited example article: use its pre-generated summary and pre-written explanations.
            st.session_state.unchanged_text = True
            summary_content = fetch_summary_contents(selected_article)
        else:
            summary_content = generate_abstractive_summary(article_text, type="beam", do_sample=True, num_beams=15,
                                                           no_repeat_ngram_size=4)
            st.session_state.unchanged_text = False
        # display_summary also stores the summary in st.session_state.summary_output.
        summary_displayed = display_summary(summary_content)
        st.write("✍ **Generated summary:** ✍", summary_displayed, unsafe_allow_html=True)

    # --- Section 1: entity matching ----------------------------------------
    st.header("1️⃣ Entity matching")
    st.markdown("The first method we will discuss is called **Named Entity Recognition** (NER). NER is the task of "
                "identifying and categorising key information (entities) in text. An entity can be a singular word or a "
                "series of words that consistently refers to the same thing. Common entity classes are person names, "
                "organisations, locations and so on. By applying NER to both the article and its summary, we can spot "
                "possible **hallucinations**. ")

    st.markdown("Hallucinations are words generated by the model that are not supported by "
                "the source input. Deep learning based generation is [prone to hallucinate]("
                "https://arxiv.org/pdf/2202.03629.pdf) unintended text. These hallucinations degrade "
                "system performance and fail to meet user expectations in many real-world scenarios. By applying entity matching, we can improve this problem"
                " for the downstream task of summary generation.")

    st.markdown(" In theory all entities in the summary (such as dates, locations and so on), "
                "should also be present in the article. Thus we can extract all entities from the summary and compare "
                "them to the entities of the original article, spotting potential hallucinations. The more unmatched "
                "entities we find, the lower the factualness score of the summary. ")
    with st.spinner("Calculating and matching entities, this takes about 10-20 seconds..."):
        # Also caches the article's entities in st.session_state.entities_article (reused for ranking below).
        entity_match_html = highlight_entities()
        st.markdown("####")
        st.write(entity_match_html, unsafe_allow_html=True)
        red_text = """<font color="black"><span style="background-color: rgb(238, 135, 135); opacity:
1;">red</span></font> """
        green_text = """<font color="black">
<span style="background-color: rgb(121, 236, 121); opacity: 1;">green</span>
</font>"""

        st.markdown(
            "We call this technique **entity matching** and here you can see what this looks like when we apply this "
            "method on the summary. Entities in the summary are marked " + green_text + " when the entity also "
            "exists in the article, "
            "while unmatched entities "
            "are marked " + red_text +
            ". Several of the example articles and their summaries indicate different errors we find by using this "
            "technique. Based on the current article, we provide a short explanation of the results below **(only for "
            "example articles)**. ", unsafe_allow_html=True)
        if st.session_state.unchanged_text:
            entity_specific_text = fetch_entity_specific_contents(selected_article)
            soup = BeautifulSoup(entity_specific_text, features="html.parser")
            st.markdown("####")
            st.write("💡👇 **Specific example explanation** 👇💡", HTML_WRAPPER.format(soup), unsafe_allow_html=True)

    # --- Section 2: dependency comparison ----------------------------------
    st.header("2️⃣ Dependency comparison")
    st.markdown(
        "The second method we use for post-processing is called **Dependency parsing**: the process in which the "
        "grammatical structure in a sentence is analysed, to find out related words as well as the type of the "
        "relationship between them. For the sentence “Jan’s wife is called Sarah” you would get the following "
        "dependency graph:")

    st.write(render_svg('ExampleParsing.svg'), unsafe_allow_html=True)
    st.markdown(
        "Here, *“Jan”* is the *“poss”* (possession modifier) of *“wife”*. If suddenly the summary would read *“Jan’s"
        " husband…”*, there would be a dependency in the summary that is non-existent in the article itself (namely "
        "*“Jan”* is the “poss” of *“husband”*). "
        "However, often new dependencies are introduced in the summary that "
        "are still correct, as can be seen in the example below. ")
    st.write(render_svg('SecondExampleParsing.svg'), unsafe_allow_html=True)

    st.markdown("*“The borders of Ukraine”* have a different dependency between *“borders”* and "
                "*“Ukraine”* "
                "than *“Ukraine’s borders”*, while both descriptions have the same meaning. So just matching all "
                "dependencies between article and summary (as we did with entity matching) would not be a robust method."
                " More on the different sorts of dependencies and their description can be found [here](https://universaldependencies.org/docs/en/dep/).")
    st.markdown("However, we have found that **there are specific dependencies that are often an "
                "indication of a wrongly constructed sentence** -when there is no article match. We (currently) use 2 "
                "common dependencies which - when present in the summary but not in the article - are highly "
                "indicative of factualness errors. "
                "Furthermore, we only check dependencies between an existing **entity** and its direct connections. "
                "Below we highlight all unmatched dependencies that satisfy the discussed constraints. We also "
                "discuss the specific results for the currently selected example article.")
    with st.spinner("Doing dependency parsing..."):
        if st.session_state.unchanged_text:
            for cur_svg_image in fetch_dependency_svg(selected_article):
                st.write(cur_svg_image, unsafe_allow_html=True)
            dep_specific_text = fetch_dependency_specific_contents(selected_article)
            soup = BeautifulSoup(dep_specific_text, features="html.parser")
            st.write("💡👇 **Specific example explanation** 👇💡", HTML_WRAPPER.format(soup), unsafe_allow_html=True)
        else:
            summary_deps = check_dependency(False)
            article_deps = check_dependency(True)
            # A summary dependency is unmatched when its identifier is not a substring of any article identifier.
            total_unmatched_deps = [summ_dep for summ_dep in summary_deps
                                    if not any(summ_dep['identifier'] in art_dep['identifier']
                                               for art_dep in article_deps)]
            for current_drawing_list in total_unmatched_deps:
                render_dependency_parsing(current_drawing_list)

    # --- Section 3: scoring and ranking candidate summaries -----------------
    st.header("🤝 Bringing it together")
    st.markdown("We have presented 2 methods that try to detect errors in summaries via post-processing steps. Entity "
                "matching can be used to solve hallucinations, while dependency comparison can be used to filter out "
                "some bad sentences (and thus worse summaries). These methods highlight the possibilities of "
                "post-processing AI-made summaries, but are only a first introduction. As the methods were "
                "empirically tested they are definitely not sufficiently robust for general use-cases.")
    st.markdown("####")
    st.markdown(
        "Below we generate 3 different kinds of summaries, and based on the two discussed methods, their errors are "
        "detected to estimate a factualness score. Based on this basic approach, "
        "the best summary (read: the one that a human would prefer or indicate as the best one) "
        "will hopefully be at the top. Summaries with the same scores will get the same rank displayed. We currently "
        "only do this for the example articles (for which the different summaries are already generated). The reason "
        "for this is that HuggingFace spaces are limited in their CPU memory.")
    st.markdown("####")

    if selected_article != "Provide your own input" and article_text == fetch_article_contents(selected_article):
        with st.spinner("Calculating more summaries and scoring them, this might take a minute or two..."):
            # The article side is identical for every candidate, so parse its dependencies only once.
            article_deps = check_dependency(True)
            summaries_list = []
            deduction_points = []
            for rank_number in range(1, 4):
                # get_and_compare_entities / check_dependency read the summary from session state.
                st.session_state.summary_output = fetch_ranked_summaries(selected_article, rank_number)
                # False: reuse the article entities cached by highlight_entities() above.
                _, amount_unmatched = get_and_compare_entities(False)
                summary_deps = check_dependency(False)
                total_unmatched_deps = [summ_dep for summ_dep in summary_deps
                                        if not any(summ_dep['identifier'] in art_dep['identifier']
                                                   for art_dep in article_deps)]
                summaries_list.append(st.session_state.summary_output)
                # Fewer unmatched entities + dependencies = fewer deduction points = better summary.
                deduction_points.append(len(amount_unmatched) + len(total_unmatched_deps))

            # Sort by points; ties fall back to comparing the summary texts (tuple ordering).
            ranked = sorted(zip(deduction_points, summaries_list))
            sorted_points = [points for points, _ in ranked]
            for points, summary in ranked:
                # Competition ranking (1, 1, 3, ...): equal scores share a rank. The first index of a
                # score in the sorted list, plus one, is exactly that rank.
                rank = sorted_points.index(points) + 1
                st.write(f'🏆 Rank {rank} summary: 🏆', display_summary(summary), unsafe_allow_html=True)
|
|