Spaces:
Build error
Build error
""" | |
@author:jishnuprakash | |
""" | |
import nltk | |
nltk.download('stopwords') | |
import os | |
import torch | |
import spacy | |
import utils as ut | |
import streamlit as st | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from spacy import displacy | |
from nltk import word_tokenize | |
from nltk.probability import FreqDist | |
from matplotlib import pyplot as plt | |
from nltk.corpus import stopwords | |
from tqdm.auto import tqdm | |
from datasets import load_dataset | |
from transformers import AutoTokenizer, AutoModel | |
from pytorch_lightning.metrics.functional import accuracy, f1, auroc | |
from sklearn.metrics import classification_report | |
st.set_page_config(page_title='NLP Challenge- JP', layout='wide', page_icon=':computer:')
# Recent Streamlit releases no longer define this option, and st.set_option()
# raises for unknown keys. This file never calls st.pyplot(), so skipping the
# option when it is unavailable is harmless.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except Exception:  # StreamlitAPIException on versions without the option
    pass
# This is the header
st.markdown("<h1 style='text-align: center; color: black;'>NLP Challenge - HM Land Registry</h1>", unsafe_allow_html=True)
st.markdown("<h3 style='text-align: center; color: grey;'>Multi-label classification using BERT Transformers</h3>", unsafe_allow_html=True)
# Fixed: the style attribute previously had a stray extra quote (center'').
st.markdown("<div style='text-align: center;'> Submission by: Jishnu Prakash Kunnanath Poduvattil | Portfolio: <a href='https://jishnuprakash.github.io/'>jishnuprakash.github.io</a> | Source Code: <a href='https://github.com/Jishnuprakash/lexGLUE_jishnuprakash'>Github</a> </div>", unsafe_allow_html=True)
st.text('')
expander = st.expander("View Description")
# The trailing "\\" emits a single backslash, i.e. a Markdown hard line break.
expander.write("""This is a minimal user interface implementation to view and interact with
results obtained from fine-tuned BERT transformers trained on the LEX GLUE: ECTHR_A dataset.
Try inputting a text below and see the model predictions. You can also extract the location
and date entities from the text using the checkbox.\\
Below, you can do the same on test data. """)
def _resource_cache():
    """Return the decorator Streamlit offers for caching heavyweight objects.

    Prefers ``st.cache_resource``. Older Streamlit releases only provide
    ``st.experimental_singleton`` or the legacy ``st.cache``.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource
    if hasattr(st, "experimental_singleton"):
        return st.experimental_singleton
    return st.cache(allow_output_mutation=True)


# Streamlit re-executes this whole script on every widget interaction. Without
# caching, the checkpoint, tokenizer, test split and spaCy pipeline below would
# be reloaded on each click.
@_resource_cache()
def load_model():
    """Load the fine-tuned classifier and the resources the UI needs.

    Returns:
        tuple: ``(trained_model, tokenizer, test, NER)``:
            * ``trained_model`` -- ``ut.LexGlueTagger`` restored from
              ``ut.check_filename + '.ckpt'``, put in eval mode and frozen.
            * ``tokenizer`` -- ``AutoTokenizer`` for ``ut.bert_model``.
            * ``test`` -- the LexGLUE ``ecthr_a`` test split converted to a
              DataFrame and passed through ``ut.preprocess_data``.
            * ``NER`` -- spaCy's ``en_core_web_sm`` pipeline.

    NOTE(review): the returned objects are shared across reruns/sessions once
    cached, so callers should copy before mutating ``test``.
    """
    trained_model = ut.LexGlueTagger.load_from_checkpoint(
        ut.check_filename + '.ckpt', num_classes=ut.num_classes)
    # Initialise BERT tokenizer
    tokenizer = AutoTokenizer.from_pretrained(ut.bert_model)
    # Set to eval and freeze to avoid weight updates
    trained_model.eval()
    trained_model.freeze()
    test = load_dataset("lex_glue", "ecthr_a")['test']
    test = ut.preprocess_data(pd.DataFrame(test))
    # Load model from spaCy
    NER = spacy.load("en_core_web_sm")
    return (trained_model, tokenizer, test, NER)


trained_model, tokenizer, test, NER = load_model()
st.header("Try out a text!")
with st.form('model_prediction'):
    # Pre-fill the input with the first 1525 characters of the first test case.
    text = st.text_area("Input Text", " ".join(test.iloc[0]['text'])[:1525])
    n1, n2, n3 = st.columns((0.2, 0.4, 0.4))  # n3 is left empty as a spacer
    ner_check = n1.checkbox("Extract Location and Date", value=True)
    predict = n2.form_submit_button("Predict")
    if predict:
        with st.spinner("Predicting..."):
            # truncation=True is required: with padding="max_length" alone the
            # tokenizer does not cut inputs longer than max_length, and
            # BERT-style models accept at most 512 positions.
            encoding = tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=512,
                truncation=True,
                return_token_type_ids=False,
                padding="max_length",
                return_attention_mask=True,
                return_tensors='pt',
            )
            # Predict on text. The second model output holds the scores that
            # are compared against ut.threshold and mapped via ut.lex_classes.
            with torch.no_grad():
                _, prediction = trained_model(encoding["input_ids"], encoding["attention_mask"])
            probabilities = prediction.flatten().numpy()
            # enumerate() keeps each class's own index. The old list.index()
            # lookup returned the first equal score, mislabelling tied classes.
            final_predictions = [idx for idx, prob in enumerate(probabilities) if prob > ut.threshold]
            if final_predictions:
                for idx in final_predictions:
                    st.write('Violations: ' + ut.lex_classes[idx] + ' : '
                             + str(round(probabilities[idx] * 100, 2)) + ' %')
            else:
                # Report the real cut-off instead of assuming it is 50%.
                st.write(f"Confidence less than {ut.threshold * 100:g}%, Please try another text.")
            if ner_check:
                # Perform NER on the single text; GPE entities are shown as locations.
                n_text = NER(text)
                locations = [ent.text for ent in n_text.ents if ent.label_ == 'GPE']
                dates = [ent.text for ent in n_text.ents if ent.label_ == 'DATE']
                # join() avoids the trailing ", " the old concatenation left behind.
                st.write("Location entities: " + (", ".join(locations) or "None found"))
                st.write("Date entities: " + (", ".join(dates) or "None found"))
                # Display all entities highlighted inline
                st.write("All Entities-")
                ent_html = displacy.render(n_text, style="ent", jupyter=False)
                st.markdown(ent_html, unsafe_allow_html=True)
st.header("Predict on test data")
with st.form('model_test_prediction'):
    s1, s2, s3 = st.columns((0.2, 0.4, 0.4))
    # Clamp the default so number_input never receives value > max_value.
    top = s1.number_input("Count", 1, len(test), value=min(10, len(test)))
    ner_check2 = s2.checkbox("Extract Location and Date", value=True)
    predict2 = s2.form_submit_button("Predict")
    if predict2:
        with st.spinner("Predicting on test data"):
            test_dataset = ut.LexGlueDataset(test.head(top), tokenizer, max_tokens=512)
            # Predict on the first `top` test cases, one item at a time.
            predictions = []
            labels = []
            with torch.no_grad():
                for item in tqdm(test_dataset):
                    _, prediction = trained_model(item["input_ids"].unsqueeze(dim=0),
                                                  item["attention_mask"].unsqueeze(dim=0))
                    predictions.append(prediction.flatten())
                    labels.append(item["labels"].int())
            predictions = torch.stack(predictions)
            labels = torch.stack(labels)
            # Accuracy, as a percentage rounded to 2 decimals
            acc = round(float(accuracy(predictions, labels, threshold=ut.threshold)) * 100, 2)
            # Boolean mask of classes whose score clears the threshold.
            y_pred = predictions.numpy() > ut.threshold
            # Copy so the added/overwritten columns never touch the dataset's frame.
            out = test_dataset.data.copy()
            # flatnonzero yields every predicted class index. The old
            # list(row).index(1) repeated the first index for multi-label rows.
            out['predictions'] = [[ut.lex_classes[i] for i in np.flatnonzero(row).tolist()]
                                  for row in y_pred]
            out['labels'] = out['labels'].apply(lambda x: [ut.lex_classes[i] for i in x])
            if ner_check2:
                # Perform NER on each test case (its text is joined into one string).
                docs = out.text.apply(lambda x: NER(" ".join(x)))
                # Sorted lists instead of sets, for a deterministic, Arrow-serialisable display.
                out['location'] = docs.apply(
                    lambda doc: sorted({ent.text for ent in doc.ents if ent.label_ == 'GPE'}))
                out['date'] = docs.apply(
                    lambda doc: sorted({ent.text for ent in doc.ents if ent.label_ == 'DATE'}))
            st.dataframe(out)
            s3.metric(label='Accuracy', value=acc)