NER-Demos / app.py
harshildarji's picture
update app
4a454ea
raw
history blame
7.68 kB
import warnings
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
logging,
pipeline,
)
# Keep the demo output quiet: ignore every Python warning category and limit
# the transformers logger (imported above as `logging`) to errors only.
warnings.simplefilter("ignore", Warning)
logging.set_verbosity(logging.ERROR)
# Page-wide settings: browser tab title and wide layout.
st.set_page_config(page_title="CAROLL Language Models - Demo", layout="wide")

# Stylesheet for the page. It is passed to st.markdown as raw HTML, so the
# <style> element is emitted verbatim instead of being rendered as text.
_PAGE_CSS = """
<style>
body {
font-family: 'Poppins', sans-serif;
background-color: #f4f4f8;
}
.header {
background-color: rgba(220, 219, 219, 0.25);
color: #000;
padding: 5px 0;
text-align: center;
border-radius: 7px;
margin-bottom: 13px;
border-bottom: 2px solid #333;
}
#logo {
width: auto;
height: 75px;
margin-top: -15px;
margin-bottom: 15px;
}
.container {
background-color: #fff;
padding: 30px;
border-radius: 10px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
width: 100%;
max-width: 1000px;
margin: 0 auto;
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}
.btn-primary {
background-color: #5477d1;
border: none;
transition: background-color 0.3s, transform 0.2s;
border-radius: 25px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
}
.btn-primary:hover {
background-color: #4c6cbe;
transform: translateY(-1px);
}
h2 {
font-weight: 600;
font-size: 24px;
margin-bottom: 20px;
}
h4 {
font-weight: 500;
font-size: 15px;
margin-top: 15px;
margin-bottom: 15px;
}
label {
font-weight: 500;
}
.tip {
background-color: rgba(180, 47, 109, 0.25);
padding: 7px;
border-radius: 7px;
display: inline-block;
margin-top: 15px;
margin-bottom: 15px;
}
.sec {
background-color: rgba(220, 219, 219, 0.10);
padding: 7px;
border-radius: 5px;
display: inline-block;
margin-top: 15px;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Banner at the top of the page: the group logo plus a linked caption. It uses
# the `.header` and `#logo` rules from the stylesheet injected earlier.
_HEADER_HTML = """
<div class="header">
<img src="https://raw.githubusercontent.com/ca-roll/ca-roll.github.io/release/images/logopic/caroll.png" alt="Research Group Logo" id="logo">
<h4>Demonstrating <a href="https://ca-roll.github.io/" target="_blank">CAROLL Research Group</a>'s Language Models</h4>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
@st.cache_resource
def _load_ner_components(model_name):
    """Load tokenizer, model and NER pipeline for ``model_name`` exactly once.

    Streamlit re-executes this script from the top on every widget
    interaction. Without caching, both checkpoints would be re-instantiated on
    each rerun; ``st.cache_resource`` keeps the loaded objects alive and hands
    back the same instances on subsequent runs.

    Args:
        model_name: Hugging Face Hub identifier of a token-classification
            checkpoint.

    Returns:
        A ``(tokenizer, model, ner_pipeline)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return tokenizer, model, pipeline("ner", model=model, tokenizer=tokenizer)


# Initialization for Legal NER
tokenizer_legal, model_legal, ner_legal = _load_ner_components(
    "PaDaS-Lab/gbert-legal-ner"
)
# Initialization for GDPR Privacy Policy NER
tokenizer_gdpr, model_gdpr, ner_gdpr = _load_ner_components(
    "PaDaS-Lab/gdpr-privacy-policy-ner"
)
# Tag -> human-readable description for the Legal NER model. The description
# is what color_substrings() puts into the hover tooltip of a tagged span.
classes_legal = {
    "AN": "Lawyer",
    "EUN": "European legal norm",
    "GRT": "Court",
    "GS": "Law",
    "INN": "Institution",
    "LD": "Country",
    "LDS": "Landscape",
    "LIT": "Legal literature",
    "MRK": "Brand",
    "ORG": "Organization",
    "PER": "Person",
    "RR": "Judge",
    "RS": "Court decision",
    "ST": "City",
    "STR": "Street",
    "UN": "Company",
    "VO": "Ordinance",
    "VS": "Regulation",
    "VT": "Contract",
}

# Tag -> human-readable description for the GDPR Privacy Policy NER model.
classes_gdpr = {
    "DC": "Data Controller",
    "DP": "Data Processor",
    "DPO": "Data Protection Officer",
    "R": "Recipient",
    "TP": "Third Party",
    "A": "Authority",
    "DS": "Data Subject",
    "DSO": "Data Source",
    "RP": "Required Purpose",
    "NRP": "Not-Required Purpose",
    "P": "Processing",
    "NPD": "Non-Personal Data",
    "PD": "Personal Data",
    "OM": "Organisational Measure",
    "TM": "Technical Measure",
    "LB": "Legal Basis",
    "CONS": "Consent",
    "CONT": "Contract",
    "LI": "Legitimate Interest",
    "ADM": "Automated Decision Making",
    "RET": "Retention",
    "SEU": "Scale EU",
    "SNEU": "Scale Non-EU",
    "RI": "Right",
    "DSR15": "Art. 15 Right of access by the data subject",
    "DSR16": "Art. 16 Right to rectification",
    "DSR17": "Art. 17 Right to erasure (‘right to be forgotten’)",
    "DSR18": "Art. 18 Right to restriction of processing",
    "DSR19": "Art. 19 Notification obligation regarding rectification or erasure of personal data or restriction of processing",
    "DSR20": "Art. 20 Right to data portability",
    "DSR21": "Art. 21 Right to object",
    "DSR22": "Art. 22 Automated individual decision-making, including profiling",
    "LC": "Lodge Complaint",
}

# Ordered tag lists (dict insertion order); a tag's position picks its colour.
ner_labels_legal = list(classes_legal)
ner_labels_gdpr = list(classes_gdpr)
def generate_colors(num_colors):
    """Return ``num_colors`` hex colour strings for visualising NER labels.

    Args:
        num_colors: Number of colours to produce; ``0`` yields an empty list.

    Returns:
        A list of ``"#rrggbb"`` strings of length ``num_colors``.
    """
    # "tab20" is a qualitative palette with exactly 20 entries. Sampling it at
    # more than 20 positions makes several positions land on the same entry,
    # so distinct labels would be drawn in identical colours. For larger label
    # sets fall back to the continuous "hsv" map, which gives every position
    # its own hue. Results for 20 or fewer colours are unchanged.
    palette = "tab20" if num_colors <= 20 else "hsv"
    cm = plt.get_cmap(palette)
    return [mcolors.rgb2hex(cm(1.0 * i / num_colors)) for i in range(num_colors)]
def color_substrings(input_string, model_output, ner_labels, current_classes):
    """Render ``input_string`` as HTML with every recognised entity coloured.

    Args:
        input_string: The raw text that was analysed.
        model_output: Iterable of dicts with ``"start"``, ``"end"`` (character
            offsets into ``input_string``) and ``"label"`` keys.
        ner_labels: Ordered list of all labels; a label's position selects its
            colour.
        current_classes: Mapping from label to the description shown as the
            hover tooltip. Labels missing from it get an empty tooltip.

    Returns:
        An HTML string in which each entity is wrapped in a bold, coloured
        ``<span>`` and all text taken from ``input_string`` is HTML-escaped.
    """
    # Local import so this function does not depend on a file-level import.
    from html import escape

    colors = generate_colors(len(ner_labels))
    label_to_color = dict(zip(ner_labels, colors))

    pieces = []
    last_end = 0
    for entity in sorted(model_output, key=lambda x: x["start"]):
        start, end, label = entity["start"], entity["end"], entity["label"]
        # Never step back into text that was already emitted: an overlapping
        # entity is clamped to its not-yet-rendered tail, or dropped entirely
        # if nothing of it remains. This prevents duplicated text.
        start = max(start, last_end)
        if start >= end:
            continue
        # The result is rendered with unsafe_allow_html=True, so user-supplied
        # text must be escaped or it would be interpreted as markup.
        pieces.append(escape(input_string[last_end:start]))
        tooltip = escape(current_classes.get(label, ""))
        # Unknown labels keep the surrounding text colour instead of emitting
        # the invalid CSS value "None".
        color = label_to_color.get(label, "inherit")
        pieces.append(
            f'<span style="color: {color}; font-weight: bold;" '
            f'title="{tooltip}">{escape(input_string[start:end])}</span>'
        )
        last_end = end
    pieces.append(escape(input_string[last_end:]))
    return "".join(pieces)
# ---- Main page: input widgets, model selection and result rendering ----
st.title("CAROLL Language Models - Demo")
st.markdown("<hr>", unsafe_allow_html=True)
test_sentence = st.text_area("Enter Text:", height=200)
model_choice = st.selectbox(
    "Choose a model:", ["Legal NER", "GDPR Privacy Policy NER"], index=0
)

if st.button("Analyze"):
    if not test_sentence.strip():
        # Nothing to tag: skip the model call and tell the user why no result
        # is shown instead of rendering an empty "Analyzed text" section.
        st.warning("Please enter some text to analyze.")
    else:
        # Pick the pipeline together with its label metadata.
        if model_choice == "Legal NER":
            ner_model = ner_legal
            current_classes = classes_legal
            current_ner_labels = ner_labels_legal
        else:
            ner_model = ner_gdpr
            current_classes = classes_gdpr
            current_ner_labels = ner_labels_gdpr

        results = ner_model(test_sentence)
        # Keep only what the renderer needs. The text after the last "-" in
        # the predicted tag is used as the label (e.g. "B-PER" -> "PER").
        processed_results = [
            {
                "start": result["start"],
                "end": result["end"],
                "label": result["entity"].split("-")[-1],
            }
            for result in results
        ]
        colored_html = color_substrings(
            test_sentence, processed_results, current_ner_labels, current_classes
        )
        st.markdown(
            '<div class="sec"><strong>Analyzed text</strong></div><br><br>{}<br><br>'.format(
                colored_html
            ),
            unsafe_allow_html=True,
        )
        st.markdown(
            '<div class="tip"><strong>Tip:</strong> Hover over the colored words to see its class.</div>',
            unsafe_allow_html=True,
        )