# lex_glue_ecthrA / app.py
# Last commit: 74b1387 ("column width") by jishnuprakash, 6.89 kB
"""
@author:jishnuprakash
"""
import nltk
nltk.download('stopwords')
import os
import torch
import spacy
import utils as ut
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from nltk import word_tokenize
from nltk.probability import FreqDist
from matplotlib import pyplot as plt
from nltk.corpus import stopwords
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from pytorch_lightning.metrics.functional import accuracy, f1, auroc
from sklearn.metrics import classification_report
# set_page_config must be the first Streamlit command executed by the script.
st.set_page_config(page_title='NLP Challenge- JP', layout='wide', page_icon=':computer:')
# Silence the deprecation warning for st.pyplot() calls without an explicit figure.
# NOTE(review): this option was removed in newer Streamlit releases; confirm the pinned version.
st.set_option('deprecation.showPyplotGlobalUse', False)

# Page header: centred title, subtitle and author/source links rendered as raw HTML.
st.markdown("<h1 style='text-align: center; color: black;'>NLP Challenge - HM Land Registry</h1>", unsafe_allow_html=True)
st.markdown("<h3 style='text-align: center; color: grey;'>Multi-label classification using BERT Transformers</h3>", unsafe_allow_html=True)
# Each attribute value is closed by exactly one quote so the markup stays well-formed.
st.markdown("<div style='text-align: center'> Submission by: Jishnu Prakash Kunnanath Poduvattil | Portfolio:<a href='https://jishnuprakash.github.io/'>jishnuprakash.github.io</a> | Source Code: <a href='https://github.com/Jishnuprakash/lexGLUE_jishnuprakash'>Github</a> </div>", unsafe_allow_html=True)
st.text('')

# Collapsible description of what the app does.
expander = st.expander("View Description")
# The trailing "\\" becomes a single backslash at end of line: a Markdown hard line break.
expander.write("""This is a minimal user interface implementation to view and interact with
results obtained from fine-tuned BERT transformers trained on LEX GLUE: ECTHR_A dataset.
Try inputting a text below and see the model predictions. You can also extract the location
and date entities from the text using the checkbox.\\
Below, you can do the same on test data. """)
# Load every heavy resource once per server process. allow_output_mutation=True tells
# st.cache not to check the returned objects for mutation between reruns.
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the fine-tuned classifier and its supporting resources.

    Returns:
        tuple: ``(trained_model, tokenizer, test, NER)``:
            - the ``ut.LexGlueTagger`` checkpoint, in eval mode with frozen weights;
            - the Hugging Face tokenizer for ``ut.bert_model``;
            - the LexGLUE ``ecthr_a`` test split as a DataFrame, passed through
              ``ut.preprocess_data``;
            - the spaCy ``en_core_web_sm`` pipeline used for entity extraction.
    """
    checkpoint_path = ut.check_filename + '.ckpt'
    model = ut.LexGlueTagger.load_from_checkpoint(checkpoint_path, num_classes=ut.num_classes)
    # Inference only: switch to eval mode and freeze parameters so weights never update.
    model.eval()
    model.freeze()

    bert_tokenizer = AutoTokenizer.from_pretrained(ut.bert_model)

    raw_split = load_dataset("lex_glue", "ecthr_a")['test']
    test_frame = ut.preprocess_data(pd.DataFrame(raw_split))

    spacy_pipeline = spacy.load("en_core_web_sm")
    return model, bert_tokenizer, test_frame, spacy_pipeline


# Module-level handles shared by both prediction forms below.
trained_model, tokenizer, test, NER = load_model()
# Section 1: classify a free-text input and optionally extract location/date entities.
st.header("Try out a text!")
with st.form('model_prediction'):
    # Pre-fill the box with the (truncated) first test document as a demo input.
    text = st.text_area("Input Text", " ".join(test.iloc[0]['text'])[:1525])
    # The third column is only a spacer that keeps the (0.2, 0.4, 0.4) layout.
    n1, n2, _ = st.columns((0.2, 0.4, 0.4))
    ner_check = n1.checkbox("Extract Location and Date", value=True)
    predict = n2.form_submit_button("Predict")
    if predict:
        with st.spinner("Predicting..."):
            # truncation=True is required: when padding is enabled the tokenizer does not
            # truncate on its own, so pasted text longer than 512 tokens would exceed the
            # encoder's maximum sequence length and make the model call fail.
            encoding = tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=512,
                truncation=True,
                return_token_type_ids=False,
                padding="max_length",
                return_attention_mask=True,
                return_tensors='pt',
            )
            # The model returns a 2-tuple; only the second item (per-class scores that are
            # compared against ut.threshold) is used here.
            with torch.no_grad():
                _, prediction = trained_model(encoding["input_ids"], encoding["attention_mask"])
            scores = prediction.flatten().tolist()
            # enumerate rather than list.index(): index() returns the first equal value, so
            # classes with identical scores would be reported as the same (wrong) class.
            final_predictions = [i for i, score in enumerate(scores) if score > ut.threshold]
            if final_predictions:
                for i in final_predictions:
                    st.write('Violations: ' + ut.lex_classes[i] + ' : ' + str(round(scores[i] * 100, 2)) + ' %')
            else:
                # Report the configured threshold instead of a hard-coded percentage.
                st.write(f"Confidence less than {ut.threshold:.0%}, please try another text.")

            if ner_check:
                # Perform NER on the single text. dict.fromkeys removes duplicates while
                # keeping first-seen order; GPE = geopolitical entity (countries, cities).
                doc = NER(text)
                locations = list(dict.fromkeys(ent.text for ent in doc.ents if ent.label_ == 'GPE'))
                dates = list(dict.fromkeys(ent.text for ent in doc.ents if ent.label_ == 'DATE'))
                st.write("Location entities: " + (", ".join(locations) or "None found"))
                st.write("Date entities: " + (", ".join(dates) or "None found"))
# Section 2: run the model over the first `top` test documents, show predictions vs.
# labels (plus optional location/date entities) and the resulting accuracy.
st.header("Predict on test data")
with st.form('model_test_prediction'):
    s1, s2, s3 = st.columns((0.2, 0.4, 0.4))
    # Clamp the default so number_input stays valid if the test split has < 10 rows.
    top = s1.number_input("Count", 1, len(test), value=min(10, len(test)))
    ner_check2 = s2.checkbox("Extract Location and Date", value=True)
    predict2 = s2.form_submit_button("Predict")
    if predict2:
        with st.spinner("Predicting on test data"):
            test_dataset = ut.LexGlueDataset(test.head(top), tokenizer, max_tokens=512)

            # Score each document individually (batch of one); no gradients are needed.
            scores, targets = [], []
            with torch.no_grad():
                for item in tqdm(test_dataset):
                    _, prediction = trained_model(item["input_ids"].unsqueeze(dim=0),
                                                  item["attention_mask"].unsqueeze(dim=0))
                    scores.append(prediction.flatten())
                    targets.append(item["labels"].int())
            predictions = torch.stack(scores)
            labels = torch.stack(targets)

            acc = round(float(accuracy(predictions, labels, threshold=ut.threshold)) * 100, 2)

            # Boolean matrix (one row per document, one column per class) of scores above
            # the decision threshold.
            y_pred = predictions.numpy() > ut.threshold

            # Copy before adding columns so we never write into the dataset's own frame.
            # NOTE(review): presumably `.data` is the test.head(top) frame passed in above;
            # confirm in ut.LexGlueDataset.
            out = test_dataset.data.copy()
            # enumerate rather than list(row).index(1): index() always returns the FIRST
            # positive class, so multi-label rows showed one class repeated.
            out['predictions'] = [[ut.lex_classes[i] for i, hit in enumerate(row) if hit]
                                  for row in y_pred]
            out['labels'] = out['labels'].apply(lambda x: [ut.lex_classes[i] for i in x])

            if ner_check2:
                # Perform NER on each document (its paragraphs joined with spaces) and keep
                # the distinct GPE / DATE entity strings.
                docs = [NER(" ".join(paragraphs)) for paragraphs in out['text']]
                out['location'] = [{ent.text for ent in doc.ents if ent.label_ == 'GPE'} for doc in docs]
                out['date'] = [{ent.text for ent in doc.ents if ent.label_ == 'DATE'} for doc in docs]

        st.dataframe(out)
        s3.metric(label='Accuracy', value=acc, delta='', delta_color='inverse')