"""
@author:jishnuprakash

Streamlit UI for a BERT multi-label classifier fine-tuned on the LexGLUE
ECtHR-A dataset: score free text or a slice of the test split, optionally
extracting location (spaCy GPE) and date (spaCy DATE) entities.
"""
import nltk
# NOTE(review): executed on every Streamlit rerun; `stopwords` is not used below.
nltk.download('stopwords')
import os
import torch
import spacy
import utils as ut
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from spacy import displacy
from nltk import word_tokenize
from nltk.probability import FreqDist
from nltk.corpus import stopwords
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from pytorch_lightning.metrics.functional import accuracy, f1, auroc
from sklearn.metrics import classification_report

# Token budget for both the free-text and the test-data paths.
MAX_TOKENS = 512

st.set_page_config(page_title='NLP Challenge- JP', layout='wide', page_icon=':computer:')
st.set_option('deprecation.showPyplotGlobalUse', False)

# Header.
# NOTE(review): the HTML markup of these three lines was reconstructed
# (only their text survived); confirm styling/links against the original.
st.markdown("<h1 style='text-align: center;'>NLP Challenge - HM Land Registry</h1>",
            unsafe_allow_html=True)
st.markdown("<h3 style='text-align: center;'>Multi-label classification using BERT Transformers</h3>",
            unsafe_allow_html=True)
st.markdown("<p style='text-align: center;'>Submission by: Jishnu Prakash Kunnanath Poduvattil | "
            "Portfolio: <a href='https://jishnuprakash.github.io'>jishnuprakash.github.io</a> | "
            "Source Code: Github</p>",
            unsafe_allow_html=True)
st.text('')

expander = st.expander("View Description")
expander.write("""This is a minimal user interface implementation to view and interact with results obtained from fine-tuned BERT transformers trained on LEX GLUE: ECTHR_A dataset.
Try inputting a text below and see the model predictions. You can also extract the location and Date entities from the text using the checkbox.\\
Below, you can do the same on test data.
""")


@st.cache(allow_output_mutation=True)
def load_model():
    """Load the heavy resources once; st.cache reuses them across reruns.

    Returns:
        tuple: (trained_model, tokenizer, test, NER) — the fine-tuned
        ``ut.LexGlueTagger`` (eval mode, frozen), the BERT tokenizer named by
        ``ut.bert_model``, the preprocessed LexGLUE ECtHR-A test split as a
        DataFrame, and the spaCy ``en_core_web_sm`` pipeline.
    """
    trained_model = ut.LexGlueTagger.load_from_checkpoint(ut.check_filename + '.ckpt',
                                                         num_classes=ut.num_classes)
    tokenizer = AutoTokenizer.from_pretrained(ut.bert_model)
    # Eval mode + freeze so inference never updates weights.
    trained_model.eval()
    trained_model.freeze()
    test = load_dataset("lex_glue", "ecthr_a")['test']
    test = ut.preprocess_data(pd.DataFrame(test))
    NER = spacy.load("en_core_web_sm")
    return (trained_model, tokenizer, test, NER)


def _indices_above(scores, threshold):
    """Return the position of every score strictly greater than ``threshold``.

    Uses enumerate (not ``list.index``) so tied scores keep their own positions.
    """
    return [idx for idx, score in enumerate(scores) if score > threshold]


def _extract_entities(doc):
    """Return (locations, dates): unique GPE / DATE entity texts of a spaCy doc, first-seen order."""
    locations = list(dict.fromkeys(ent.text for ent in doc.ents if ent.label_ == 'GPE'))
    dates = list(dict.fromkeys(ent.text for ent in doc.ents if ent.label_ == 'DATE'))
    return locations, dates


def _format_entities(values):
    """Join entity texts with commas, or return "None found" for an empty list."""
    return ",".join(values) if values else "None found"


def _score_text(text):
    """Tokenize ``text`` and return the model's per-class scores as a list of floats."""
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=MAX_TOKENS,
        # With padding="max_length" transformers does not truncate implicitly;
        # without this, long input exceeds BERT's position limit.
        truncation=True,
        return_token_type_ids=False,
        padding="max_length",
        return_attention_mask=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        # The model returns a pair; the second item holds the class scores
        # (assumed probabilities — they are compared to ut.threshold).
        _, prediction = trained_model(encoding["input_ids"], encoding["attention_mask"])
    return prediction.flatten().tolist()


trained_model, tokenizer, test, NER = load_model()

# --- Free-text prediction -------------------------------------------------
st.header("Try out a text!")
with st.form('model_prediction'):
    text = st.text_area("Input Text", test.iloc[0]['text'][20])
    n1, n2, n3 = st.columns((0.2, 0.4, 0.4))
    ner_check = n1.checkbox("Extract Location and Date", value=True)
    predict = n2.form_submit_button("Predict")

if predict:
    with st.spinner("Predicting..."):
        scores = _score_text(text)
        predicted = _indices_above(scores, ut.threshold)
        if predicted:
            for i in predicted:
                st.write(f"Violations: {ut.lex_classes[i]} : {round(scores[i] * 100, 2)} %")
        else:
            st.write(f"Confidence less than {ut.threshold * 100:g}%, Please try another text.")

        if ner_check:
            doc = NER(text)
            locations, dates = _extract_entities(doc)
            st.write("Location entities: " + _format_entities(locations))
            st.write("Date entities: " + _format_entities(dates))
            # Render every entity spaCy found, highlighted inline.
            st.write("All Entities-")
            ent_html = displacy.render(doc, style="ent", jupyter=False)
            st.markdown(ent_html, unsafe_allow_html=True)

# --- Test-data prediction -------------------------------------------------
st.header("Predict on test data")
with st.form('model_test_prediction'):
    s1, s2, s3 = st.columns((0.2, 0.4, 0.4))
    top = s1.number_input("Count", 1, len(test), value=10)
    ner_check2 = s2.checkbox("Extract Location and Date", value=True)
    predict2 = s2.form_submit_button("Predict")

if predict2:
    with st.spinner("Predicting on test data"):
        test_dataset = ut.LexGlueDataset(test.head(top), tokenizer, max_tokens=MAX_TOKENS)

        predictions, labels = [], []
        with torch.no_grad():
            for item in tqdm(test_dataset):
                _, prediction = trained_model(item["input_ids"].unsqueeze(dim=0),
                                              item["attention_mask"].unsqueeze(dim=0))
                predictions.append(prediction.flatten())
                labels.append(item["labels"].int())
        predictions = torch.stack(predictions)
        labels = torch.stack(labels)

        acc = round(float(accuracy(predictions, labels, threshold=ut.threshold)) * 100, 2)

        # Copy so the added/rewritten columns never touch the dataset's frame.
        out = test_dataset.data.copy()
        out['predictions'] = [[ut.lex_classes[i] for i in _indices_above(row, ut.threshold)]
                              for row in predictions.tolist()]
        out['labels'] = out['labels'].apply(lambda x: [ut.lex_classes[i] for i in x])

        if ner_check2:
            # Each row's text is joined with spaces before NER.
            entities = out.text.apply(lambda paragraphs: _extract_entities(NER(" ".join(paragraphs))))
            out['location'] = entities.apply(lambda pair: pair[0])
            out['date'] = entities.apply(lambda pair: pair[1])
        st.dataframe(out)
        s3.metric(label='Accuracy', value=acc, delta='', delta_color='inverse')