import warnings
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
logging,
pipeline,
)
# Ignore every warning category so library warnings are not emitted by the app.
warnings.simplefilter("ignore", category=Warning)
# Limit the transformers logger to error-level messages.
logging.set_verbosity(logging.ERROR)

# Page-level configuration; issued before any other Streamlit call in this script.
st.set_page_config(layout="wide", page_title="CAROLL Language Models - Demo")

# Two HTML-enabled markdown blocks. In this file each body is a single line
# break, so they currently contribute no visible markup.
# NOTE(review): presumably these once carried custom CSS/HTML - confirm.
st.markdown("\n", unsafe_allow_html=True)
st.markdown("\n", unsafe_allow_html=True)
# Hugging Face checkpoints behind the two demo models.
LEGAL_NER_CHECKPOINT = "PaDaS-Lab/gbert-legal-ner"
GDPR_NER_CHECKPOINT = "PaDaS-Lab/gdpr-privacy-policy-ner"


def _load_ner_components(checkpoint):
    """Load the tokenizer, model and "ner" pipeline for one checkpoint.

    Args:
        checkpoint: Hugging Face model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model, ner_pipeline)`` tuple built from *checkpoint*.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForTokenClassification.from_pretrained(checkpoint)
    return tokenizer, model, pipeline("ner", model=model, tokenizer=tokenizer)


# Streamlit re-executes this script on every widget interaction. Caching the
# loader keeps one set of model objects per checkpoint instead of reloading
# both models on each rerun. Older Streamlit releases without
# ``st.cache_resource`` fall back to the uncached loader.
if hasattr(st, "cache_resource"):
    _load_ner_components = st.cache_resource(show_spinner=False)(
        _load_ner_components
    )

# Initialization for Legal NER
tokenizer_legal, model_legal, ner_legal = _load_ner_components(LEGAL_NER_CHECKPOINT)
# Initialization for GDPR Privacy Policy NER
tokenizer_gdpr, model_gdpr, ner_gdpr = _load_ner_components(GDPR_NER_CHECKPOINT)
# Tag -> human-readable description for the Legal NER model.
classes_legal = {
    "AN": "Lawyer",
    "EUN": "European legal norm",
    "GRT": "Court",
    "GS": "Law",
    "INN": "Institution",
    "LD": "Country",
    "LDS": "Landscape",
    "LIT": "Legal literature",
    "MRK": "Brand",
    "ORG": "Organization",
    "PER": "Person",
    "RR": "Judge",
    "RS": "Court decision",
    "ST": "City",
    "STR": "Street",
    "UN": "Company",
    "VO": "Ordinance",
    "VS": "Regulation",
    "VT": "Contract",
}

# Tag -> human-readable description for the GDPR Privacy Policy NER model.
classes_gdpr = {
    "DC": "Data Controller",
    "DP": "Data Processor",
    "DPO": "Data Protection Officer",
    "R": "Recipient",
    "TP": "Third Party",
    "A": "Authority",
    "DS": "Data Subject",
    "DSO": "Data Source",
    "RP": "Required Purpose",
    "NRP": "Not-Required Purpose",
    "P": "Processing",
    "NPD": "Non-Personal Data",
    "PD": "Personal Data",
    "OM": "Organisational Measure",
    "TM": "Technical Measure",
    "LB": "Legal Basis",
    "CONS": "Consent",
    "CONT": "Contract",
    "LI": "Legitimate Interest",
    "ADM": "Automated Decision Making",
    "RET": "Retention",
    "SEU": "Scale EU",
    "SNEU": "Scale Non-EU",
    "RI": "Right",
    "DSR15": "Art. 15 Right of access by the data subject",
    "DSR16": "Art. 16 Right to rectification",
    "DSR17": "Art. 17 Right to erasure (‘right to be forgotten’)",
    "DSR18": "Art. 18 Right to restriction of processing",
    "DSR19": "Art. 19 Notification obligation regarding rectification or erasure of personal data or restriction of processing",
    "DSR20": "Art. 20 Right to data portability",
    "DSR21": "Art. 21 Right to object",
    "DSR22": "Art. 22 Automated individual decision-making, including profiling",
    "LC": "Lodge Complaint",
}

# Ordered tag lists (dict insertion order); used to assign one color per tag.
ner_labels_legal = list(classes_legal)
ner_labels_gdpr = list(classes_gdpr)
# Function to generate a list of colors for visualization
def generate_colors(num_colors):
    """Return ``num_colors`` hex color strings sampled from the "tab20" colormap.

    The colormap is sampled at the evenly spaced positions ``i / num_colors``
    for ``i`` in ``0 .. num_colors - 1``; each sampled color is converted with
    ``matplotlib.colors.rgb2hex``.

    NOTE(review): "tab20" is a 20-entry qualitative map, so requests for more
    than 20 colors will presumably contain repeats - confirm if that matters.
    """
    palette = plt.get_cmap("tab20")
    hex_codes = []
    for index in range(num_colors):
        rgba = palette(1.0 * index / num_colors)
        hex_codes.append(mcolors.rgb2hex(rgba))
    return hex_codes
# Function to color substrings based on NER results
def color_substrings(input_string, model_output, ner_labels, current_classes):
    """Return *input_string* as HTML with each detected entity highlighted.

    Every entity span is wrapped in a ``<span>`` whose background color is
    picked per label and whose ``title`` attribute (shown on hover) carries the
    label's human-readable description. All text taken from the input is
    HTML-escaped because the result is rendered with HTML enabled.

    Args:
        input_string: The text that was analyzed.
        model_output: Iterable of dicts with ``"start"``, ``"end"`` (character
            offsets into *input_string*) and ``"label"`` keys.
        ner_labels: Ordered list of labels; one color is generated per label.
        current_classes: Mapping of label -> description used for the tooltip.

    Returns:
        A single HTML string covering the whole of *input_string*.
    """
    # Local import keeps this function self-contained.
    from html import escape

    colors = generate_colors(len(ner_labels))
    label_to_color = dict(zip(ner_labels, colors))
    fallback_color = "#dddddd"  # neutral highlight for labels without a color

    parts = []
    last_end = 0
    for entity in sorted(model_output, key=lambda item: item["start"]):
        start, end, label = entity["start"], entity["end"], entity["label"]
        if start < last_end:
            # Span overlaps one already emitted; skipping it avoids
            # duplicating the overlapping text in the output.
            continue
        parts.append(escape(input_string[last_end:start]))
        color = label_to_color.get(label, fallback_color)
        tooltip = escape(current_classes.get(label, ""))
        highlighted = escape(input_string[start:end])
        parts.append(
            f'<span style="background-color: {color};" title="{tooltip}">'
            f"{highlighted}</span>"
        )
        last_end = end
    # Trailing text after the last entity (or the whole input if none matched).
    parts.append(escape(input_string[last_end:]))
    return "".join(parts)
# --- Page body: input widgets, model selection and result rendering ---
# NOTE(review): the strings rendered below with unsafe_allow_html=True contain
# no markup in this file; they keep the exact text found here, with embedded
# line breaks written as "\n" escapes so the literals are valid Python.
# Any intended HTML needs to be re-added - confirm.
st.title("CAROLL Language Models - Demo")
st.markdown("\n", unsafe_allow_html=True)
test_sentence = st.text_area("Enter Text:", height=200)
model_choice = st.selectbox(
    "Choose a model:", ["Legal NER", "GDPR Privacy Policy NER"], index=0
)
if st.button("Analyze"):
    # Pick the pipeline, label descriptions and label order for the chosen model.
    if model_choice == "Legal NER":
        ner_model = ner_legal
        current_classes = classes_legal
        current_ner_labels = ner_labels_legal
    else:
        ner_model = ner_gdpr
        current_classes = classes_gdpr
        current_ner_labels = ner_labels_gdpr

    if not test_sentence.strip():
        # Nothing to analyze; skip the model call for blank input.
        st.warning("Please enter some text to analyze.")
    else:
        results = ner_model(test_sentence)
        # Keep only the character offsets and the bare tag: the part after the
        # last "-" of the pipeline's "entity" value.
        processed_results = [
            {
                "start": result["start"],
                "end": result["end"],
                "label": result["entity"].split("-")[-1],
            }
            for result in results
        ]
        colored_html = color_substrings(
            test_sentence, processed_results, current_ner_labels, current_classes
        )
        st.markdown(
            "Analyzed text\n{}\n".format(colored_html),
            unsafe_allow_html=True,
        )
        st.markdown(
            "Tip: Hover over the colored words to see its class.\n",
            unsafe_allow_html=True,
        )