import os
import random
import streamlit as st
from bs4 import BeautifulSoup
from custom_renderer import render_sentence_custom
from flair.data import Sentence
from flair.models import SequenceTagger
import spacy
from spacy_streamlit.util import get_svg
from transformers import pipeline
from transformers_interpret import SequenceClassificationExplainer
# Map model names to URLs
model_names_to_URLs = {
'ml6team/distilbert-base-dutch-cased-toxic-comments':
'https://huggingface.co/ml6team/distilbert-base-dutch-cased-toxic-comments',
'ml6team/robbert-dutch-base-toxic-comments':
'https://huggingface.co/ml6team/robbert-dutch-base-toxic-comments',
}
about_page_markdown = f"""# 🀬 Dutch Toxic Comment Detection Space
Made by [ML6](https://ml6.eu/).
Token attribution is performed using [transformers-interpret](https://github.com/cdpierse/transformers-interpret).
"""
regular_emojis = [
'😐', 'πŸ™‚', 'πŸ‘Ά', 'πŸ˜‡',
]
undecided_emojis = [
'🀨', '🧐', 'πŸ₯Έ', 'πŸ₯΄', '🀷',
]
potty_mouth_emojis = [
'🀐', 'πŸ‘Ώ', '😑', '🀬', '☠️', '☣️', '☒️',
]
# Page setup
st.set_page_config(
page_title="Toxic Comment Detection Space",
page_icon="🀬",
layout="centered",
initial_sidebar_state="auto",
menu_items={
'Get help': None,
'Report a bug': None,
'About': about_page_markdown,
}
)
# Model setup
@st.cache(allow_output_mutation=True,
suppress_st_warning=True,
show_spinner=False)
def load_pipeline(model_name):
with st.spinner('Loading model (this might take a while)...'):
toxicity_pipeline = pipeline(
'text-classification',
model=model_name,
tokenizer=model_name)
cls_explainer = SequenceClassificationExplainer(
toxicity_pipeline.model,
toxicity_pipeline.tokenizer)
return toxicity_pipeline, cls_explainer
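# The spaCy and Flair models used further down are re-loaded on every call.
# Cached loaders along these lines (a sketch, mirroring load_pipeline above;
# not yet wired into the functions below) would avoid the repeated loads:
@st.cache(allow_output_mutation=True, show_spinner=False)
def load_spacy_model(model_name: str = 'en_core_web_lg'):
    """Load and cache a spaCy pipeline."""
    return spacy.load(model_name)
@st.cache(allow_output_mutation=True, show_spinner=False)
def load_flair_tagger(model_name: str = "flair/ner-english-ontonotes-fast"):
    """Load and cache a Flair NER tagger."""
    return SequenceTagger.load(model_name)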
# Auxiliary functions
def format_explainer_html(html_string):
"""Extract tokens with attribution-based background color."""
inside_token_prefix = '##'
soup = BeautifulSoup(html_string, 'html.parser')
p = soup.new_tag('p',
attrs={'style': 'color: black; background-color: white;'})
# Select token elements and remove model specific tokens
current_word = None
for token in soup.find_all('td')[-1].find_all('mark')[1:-1]:
text = token.font.text.strip()
if text.startswith(inside_token_prefix):
text = text[len(inside_token_prefix):]
else:
# Create a new span for each word (sequence of sub-tokens)
if current_word is not None:
p.append(current_word)
p.append(' ')
current_word = soup.new_tag('span')
token.string = text
token.attrs['style'] = f"{token.attrs['style']}; padding: 0.2em 0em;"
current_word.append(token)
    # Add the last word (guard against input with no tokens at all)
    if current_word is not None:
        p.append(current_word)
# Add left and right-padding to each word
for span in p.find_all('span'):
span.find_all('mark')[0].attrs['style'] = (
f"{span.find_all('mark')[0].attrs['style']}; padding-left: 0.2em;")
span.find_all('mark')[-1].attrs['style'] = (
f"{span.find_all('mark')[-1].attrs['style']}; padding-right: 0.2em;")
return p
def list_all_article_names() -> list:
filenames = []
for file in os.listdir('./sample-articles/'):
if file.endswith('.txt'):
filenames.append(file.replace('.txt', ''))
return filenames
def fetch_article_contents(filename: str) -> str:
    with open(f'./sample-articles/{filename.lower()}.txt', 'r') as f:
        return f.read()
def fetch_summary_contents(filename: str) -> str:
    with open(f'./sample-summaries/{filename.lower()}.txt', 'r') as f:
        return f.read()
def classify_comment(comment, selected_model):
"""Classify the given comment and augment with additional information."""
toxicity_pipeline, cls_explainer = load_pipeline(selected_model)
result = toxicity_pipeline(comment)[0]
result['model_name'] = selected_model
# Add explanation
result['word_attribution'] = cls_explainer(comment, class_name="non-toxic")
    result['visualisation_html'] = cls_explainer.visualize()._repr_html_()
    result['tokens_with_background'] = format_explainer_html(
        result['visualisation_html'])
# Choose emoji reaction
label, score = result['label'], result['score']
if label == 'toxic' and score > 0.1:
emoji = random.choice(potty_mouth_emojis)
elif label in ['non_toxic', 'non-toxic'] and score > 0.1:
emoji = random.choice(regular_emojis)
else:
emoji = random.choice(undecided_emojis)
result.update({'text': comment, 'emoji': emoji})
# Add result to session
st.session_state.results.append(result)
# Start session
if 'results' not in st.session_state:
st.session_state.results = []
# Page
# st.title('🀬 Dutch Toxic Comment Detection')
# st.markdown("""This demo showcases two Dutch toxic comment detection models.""")
#
# # Introduction
# st.markdown(f"""Both models were trained using a sequence classification task on a translated [Jigsaw Toxicity dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge) which contains toxic online comments.
# The first model is a fine-tuned multilingual [DistilBERT](https://huggingface.co/distilbert-base-multilingual-cased) model whereas the second is a fine-tuned Dutch RoBERTa-based model called [RobBERT](https://huggingface.co/pdelobelle/robbert-v2-dutch-base).""")
# st.markdown(f"""For a more comprehensive overview of the models check out their model card on πŸ€— Model Hub: [distilbert-base-dutch-toxic-comments]({model_names_to_URLs['ml6team/distilbert-base-dutch-cased-toxic-comments']}) and [RobBERT-dutch-base-toxic-comments]({model_names_to_URLs['ml6team/robbert-dutch-base-toxic-comments']}).
# """)
# st.markdown("""Enter a comment that you want to classify below. The model will determine the probability that it is toxic and highlights how much each token contributes to its decision:
# <font color="black">
# <span style="background-color: rgb(250, 219, 219); opacity: 1;">r</span><span style="background-color: rgb(244, 179, 179); opacity: 1;">e</span><span style="background-color: rgb(238, 135, 135); opacity: 1;">d</span>
# </font>
# tokens indicate toxicity whereas
# <font color="black">
# <span style="background-color: rgb(224, 251, 224); opacity: 1;">g</span><span style="background-color: rgb(197, 247, 197); opacity: 1;">re</span><span style="background-color: rgb(121, 236, 121); opacity: 1;">en</span>
# </font> tokens indicate the opposite.
#
# Try it yourself! πŸ‘‡""",
# unsafe_allow_html=True)
# Demo
# with st.form("dutch-toxic-comment-detection-input", clear_on_submit=True):
#     selected_model = st.selectbox('Select a model:',
#                                   model_names_to_URLs.keys())
# text = st.text_area(
# label='Enter the comment you want to classify below (in Dutch):')
# _, rightmost_col = st.columns([6,1])
# submitted = rightmost_col.form_submit_button("Classify",
# help="Classify comment")
# TODO: enforce a minimum article length (see the sketch below the text area)
selected_article = st.selectbox('Select an article or provide your own:',
                                list_all_article_names())
st.session_state.article_text = fetch_article_contents(selected_article)
article_text = st.text_area(
label='Full article text',
value=st.session_state.article_text,
height=250
)
# _, rightmost_col = st.columns([5, 1])
# get_summary = rightmost_col.button("Generate summary",
# help="Generate summary for the given article text")
def display_summary(article_name: str):
    """Fetch the pre-generated summary for the article and render it."""
    st.subheader("Generated summary")
summary_content = fetch_summary_contents(article_name)
soup = BeautifulSoup(summary_content, features="html.parser")
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
st.session_state.summary_output = HTML_WRAPPER.format(soup)
st.write(st.session_state.summary_output, unsafe_allow_html=True)
# TODO: this output can be cached (e.g. with st.cache, or by storing the rendered HTML or the extracted entity lists)
def get_and_compare_entities_spacy(article_name: str):
    """Extract entities from the article and its summary with spaCy, then
    split the summary entities into those that occur in the article (as a
    substring of an article entity) and those that do not."""
    nlp = spacy.load('en_core_web_lg')
    article_content = fetch_article_contents(article_name)
    doc = nlp(article_content)
    entities_article = [str(entity) for entity in doc.ents]
    summary_content = fetch_summary_contents(article_name)
    doc = nlp(summary_content)
    entities_summary = [str(entity) for entity in doc.ents]
    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        # TODO: substring matching for now; an embedding-based comparison
        # would be more robust (see the sketch below).
        if any(entity.lower() in article_entity.lower() for article_entity in entities_article):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)
    return matched_entities, unmatched_entities
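# Sketch of the embedding-based matching hinted at in the TODO above (an
# assumption, not part of the current app): compare a summary entity against
# the article entities via spaCy vector similarity instead of substring
# containment. The 0.85 threshold is arbitrary, and en_core_web_lg is
# assumed because it ships with word vectors.
def entities_match_by_embedding(summary_entity: str,
                                article_entities: list,
                                nlp,
                                threshold: float = 0.85) -> bool:
    """Return True if the summary entity is similar to any article entity."""
    summary_doc = nlp(summary_entity)
    if not summary_doc.vector_norm:
        return False  # no vector available, cannot compare
    for article_entity in article_entities:
        article_doc = nlp(article_entity)
        if article_doc.vector_norm and summary_doc.similarity(article_doc) >= threshold:
            return True
    return False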
def get_and_compare_entities_flair(article_name: str):
    """Same as get_and_compare_entities_spacy, but the entities are tagged
    per sentence with a Flair OntoNotes NER model."""
    nlp = spacy.load('en_core_web_sm')
    tagger = SequenceTagger.load("flair/ner-english-ontonotes-fast")
    article_content = fetch_article_contents(article_name)
    doc = nlp(article_content)
    entities_article = []
    for sentence in doc.sents:
        sentence_entities = Sentence(str(sentence))
        tagger.predict(sentence_entities)
        for entity in sentence_entities.get_spans('ner'):
            entities_article.append(entity.text)
    summary_content = fetch_summary_contents(article_name)
    doc = nlp(summary_content)
    entities_summary = []
    for sentence in doc.sents:
        sentence_entities = Sentence(str(sentence))
        tagger.predict(sentence_entities)
        for entity in sentence_entities.get_spans('ner'):
            entities_summary.append(entity.text)
    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        # TODO: substring matching for now, as in the spaCy variant above.
        if any(entity.lower() in article_entity.lower() for article_entity in entities_article):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)
    return matched_entities, unmatched_entities
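# The matching loop above is duplicated in both get_and_compare_entities_*
# functions; a shared helper along these lines (a sketch) would keep the two
# variants in sync:
def split_matched_entities(entities_summary: list, entities_article: list):
    """Split summary entities into (matched, unmatched) by substring match."""
    matched, unmatched = [], []
    for entity in entities_summary:
        if any(entity.lower() in candidate.lower() for candidate in entities_article):
            matched.append(entity)
        else:
            unmatched.append(entity)
    return matched, unmatched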
def highlight_entities(article_name: str):
    """Render the summary with entities colour-coded by whether they also
    occur in the source article (green) or not (red)."""
    st.subheader("Match entities with article")
summary_content = fetch_summary_contents(article_name)
markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
markdown_end = "</mark>"
matched_entities, unmatched_entities = get_and_compare_entities_spacy(article_name)
for entity in matched_entities:
summary_content = summary_content.replace(entity, markdown_start_green + entity + markdown_end)
for entity in unmatched_entities:
summary_content = summary_content.replace(entity, markdown_start_red + entity + markdown_end)
soup = BeautifulSoup(summary_content, features="html.parser")
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
st.write(HTML_WRAPPER.format(soup), unsafe_allow_html=True)
def render_dependency_parsing(text: str):
    """Render the custom dependency-parse SVG for the given text."""
    html = render_sentence_custom(text)
    # Double newlines seem to mess with the rendering
    html = html.replace("\n\n", "\n")
    st.write(get_svg(html), unsafe_allow_html=True)
def check_dependency(text: str) -> str:
    """Return the sentences of `text` in which an adjectival modifier (amod)
    either is, or is attached to, a named entity."""
    tagger = SequenceTagger.load("flair/ner-english-ontonotes-fast")
    nlp = spacy.load('en_core_web_lg')
    doc = nlp(text)
    tok_l = doc.to_json()['tokens']
    all_deps = ""
    for sentence in doc.sents:
        all_entities = []
        # Entities according to spaCy:
        for entity in sentence.ents:
            all_entities.append(str(entity))
        # Entities according to Flair:
        sentence_entities = Sentence(str(sentence))
        tagger.predict(sentence_entities)
        for entity in sentence_entities.get_spans('ner'):
            all_entities.append(entity.text)
        start_id = sentence.start
        end_id = sentence.end
        for t in tok_l:
            # Only consider tokens inside the current sentence
            # (sentence.end is exclusive).
            if t["id"] < start_id or t["id"] >= end_id:
                continue
            head = tok_l[t['head']]
            if t['dep'] == 'amod':
                object_here = text[t['start']:t['end']]
                object_target = text[head['start']:head['end']]
                # Either the modifier or its head needs to be an entity
                if object_here in all_entities or object_target in all_entities:
                    all_deps += str(sentence)
                    break  # sentence already selected, move on
    return all_deps
with st.form("article-input"):
left_column, _ = st.columns([1, 1])
get_summary = left_column.form_submit_button("Generate summary",
help="Generate summary for the given article text")
# Listener
if get_summary:
if article_text:
        with st.spinner('Generating summary...'):
            display_summary(selected_article)
    else:
        st.error('**Error**: No article text provided. Please enter or select an article.')
# Entity part
with st.form("Entity-part"):
left_column, _ = st.columns([1, 1])
draw_entities = left_column.form_submit_button("Draw Entities",
help="Draw Entities")
if draw_entities:
with st.spinner("Drawing entities..."):
highlight_entities(selected_article)
with st.form("Dependency-usage"):
left_column, _ = st.columns([1, 1])
parsing = left_column.form_submit_button("Dependency parsing",
help="Dependency parsing")
if parsing:
with st.spinner("Doing dependency parsing..."):
render_dependency_parsing(check_dependency(fetch_summary_contents(selected_article)))
# Results
# if 'results' in st.session_state and st.session_state.results:
# first = True
# for result in st.session_state.results[::-1]:
# if not first:
# st.markdown("---")
# st.markdown(f"Text:\n> {result['text']}")
# col_1, col_2, col_3 = st.columns([1,2,2])
# col_1.metric(label='', value=f"{result['emoji']}")
# col_2.metric(label='Label', value=f"{result['label']}")
# col_3.metric(label='Score', value=f"{result['score']:.3f}")
# st.markdown(f"Token Attribution:\n{result['tokens_with_background']}",
# unsafe_allow_html=True)
# st.caption(f"Model: {result['model_name']}")
# first = False