jishnuprakash commited on
Commit
bc9855c
·
1 Parent(s): 75a0868

user interface

Browse files
Files changed (1) hide show
  1. home.py +154 -0
home.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @author:jishnuprakash
3
+ """
4
+ import os
5
+ import torch
6
+ import spacy
7
+ import utils as ut
8
+ import streamlit as st
9
+ import pandas as pd
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ import pandas as pd
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import seaborn as sns
16
+ from nltk import word_tokenize
17
+ from nltk.probability import FreqDist
18
+ from matplotlib import pyplot as plt
19
+ from nltk.corpus import stopwords
20
+ from tqdm.auto import tqdm
21
+ from datasets import load_dataset
22
+ from transformers import AutoTokenizer, AutoModel
23
+ from pytorch_lightning.metrics.functional import accuracy, f1, auroc
24
+ from sklearn.metrics import classification_report
25
+
26
+
27
+ st.set_page_config(page_title='NLP Challenge- JP', layout='wide', page_icon=':computer:')
28
+ st.set_option('deprecation.showPyplotGlobalUse', False)
29
+
30
+ #this is the header
31
+ st.markdown("<h1 style='text-align: center; color: black;'>NLP Challenge - HM Land Registry</h1>", unsafe_allow_html=True)
32
+ st.markdown("<h3 style='text-align: center; color: grey;'>Multi-label classification using BERT Transformers</h3>", unsafe_allow_html=True)
33
+ st.markdown("<div style='text-align: center''> Submission by: Jishnu Prakash Kunnanath Poduvattil | Portfolio:<a href='https://jishnuprakash.github.io/'>jishnuprakash.github.io</a> | Source Code: <a href='https://github.com/Jishnuprakash/lexGLUE_jishnuprakash'>Github</a> </div>", unsafe_allow_html=True)
34
+ st.text('')
35
+ expander = st.expander("View Description")
36
+ expander.write("""This is minimal user interface implemetation to view and interact with
37
+ results obtained from fine-tuned BERT transformers trained on LEX GLUE: ECTHR_A dataset.
38
+ Try inputing a text below and see the model predictions. You can also extract the location
39
+ and Date entities from the text using the checkbox.\\
40
+ Below, you can do the same on test data. """)
41
+
42
+
43
+ #Load trained model
44
+ @st.cache(allow_output_mutation=True)
45
+ def load_model():
46
+ trained_model = ut.LexGlueTagger.load_from_checkpoint(
47
+ os.path.join(os.getcwd(), ut.checkpoint_dir, ut.check_filename+'.ckpt'),
48
+ num_classes = ut.num_classes)
49
+ #Initialise BERT tokenizer
50
+ tokenizer = AutoTokenizer.from_pretrained(ut.bert_model)
51
+ #Set to Eval and freeze to avoid weight update
52
+ trained_model.eval()
53
+ trained_model.freeze()
54
+ test = load_dataset("lex_glue", "ecthr_a")['test']
55
+ test = ut.preprocess_data(pd.DataFrame(test))
56
+ #Load Model from Spacy
57
+ NER = spacy.load("en_core_web_sm")
58
+ return (trained_model, tokenizer, test, NER)
59
+
60
+ trained_model, tokenizer, test, NER = load_model()
61
+
62
+ st.header("Try out a text!")
63
+ with st.form('model_prediction'):
64
+ text = st.text_area("Input Text", " ".join(test.iloc[0]['text'])[:1525])
65
+ n1, n2, n3 = st.columns((0.13,0.3,0.4))
66
+ ner_check = n1.checkbox("Extract Location and Date", value=True)
67
+ predict = n2.form_submit_button("Predict")
68
+ with st.spinner("Predicting..."):
69
+ if predict:
70
+ encoding = tokenizer.encode_plus(text,
71
+ add_special_tokens=True,
72
+ max_length=512,
73
+ return_token_type_ids=False,
74
+ padding="max_length",
75
+ return_attention_mask=True,
76
+ return_tensors='pt',)
77
+ # Predict on text
78
+ _, prediction = trained_model(encoding["input_ids"], encoding["attention_mask"])
79
+ prediction = list(prediction.flatten().numpy())
80
+
81
+ final_predictions = [prediction.index(i) for i in prediction if i > ut.threshold]
82
+ if len(final_predictions)>0:
83
+ for i in final_predictions:
84
+ st.write('Violations: '+ ut.lex_classes[i] + ' : ' + str(round(prediction[i]*100, 2)) + ' %')
85
+ else:
86
+ st.write("Confidence less than 50%, Please try another text.")
87
+
88
+ if ner_check:
89
+ #Perform NER on a single text
90
+ n_text = NER(text)
91
+ loc = ''
92
+ date = ''
93
+ for word in n_text.ents:
94
+ print(word.text,word.label_)
95
+ if word.label_ == 'DATE':
96
+ date += word.text + ', '
97
+ elif word.label_ == 'GPE':
98
+ loc += word.text + ', '
99
+ loc = "None found" if len(loc)<1 else loc
100
+ date = "None found" if len(date)<1 else date
101
+ st.write("Location entities: " + loc)
102
+ st.write("Date entities: " + date)
103
+
104
+ st.header("Predict on test data")
105
+ with st.form('model_test_prediction'):
106
+ s1, s2, s3 = st.columns((0.1, 0.3, 0.6))
107
+ top = s1.number_input("Count",1, len(test), value=10)
108
+ ner_check2 = s2.checkbox("Extract Location and Date", value=True)
109
+ predict2 = s2.form_submit_button("Predict")
110
+ with st.spinner("Predicting on test data"):
111
+ if predict2:
112
+ test_dataset = ut.LexGlueDataset(test.head(top), tokenizer, max_tokens=512)
113
+
114
+ # Predict on test data
115
+ predictions = []
116
+ labels = []
117
+
118
+ for item in tqdm(test_dataset):
119
+ _ , prediction = trained_model(item["input_ids"].unsqueeze(dim=0),
120
+ item["attention_mask"].unsqueeze(dim=0))
121
+ predictions.append(prediction.flatten())
122
+ labels.append(item["labels"].int())
123
+
124
+ predictions = torch.stack(predictions)
125
+ labels = torch.stack(labels)
126
+
127
+ y_pred = predictions.numpy()
128
+ y_true = labels.numpy()
129
+
130
+ #Filter predictions
131
+ upper, lower = 1, 0
132
+ y_pred = np.where(y_pred > ut.threshold, upper, lower)
133
+ # d1, d2 = st.columns((0.6, 0.4))
134
+
135
+ #Accuracy
136
+ acc = round(float(accuracy(predictions, labels, threshold=ut.threshold))*100, 2)
137
+
138
+ out = test_dataset.data
139
+ out['predictions'] = [[list(i).index(j) for j in i if j==1] for i in y_pred]
140
+ out['labels'] = out['labels'].apply(lambda x: [ut.lex_classes[i] for i in x])
141
+ out['predictions'] = out['predictions'].apply(lambda x: [ut.lex_classes[i] for i in x])
142
+
143
+ if ner_check2:
144
+ #Perform NER on Test Dataset
145
+ out['nlp_text'] = out.text.apply(lambda x: NER(" ".join(x)))
146
+
147
+ #Extract Entities
148
+ out['location'] = out.nlp_text.apply(lambda x: set([i.text for i in x.ents if i.label_=='GPE']))
149
+ out['date'] = out.nlp_text.apply(lambda x: set([i.text for i in x.ents if i.label_=='DATE']))
150
+
151
+ st.dataframe(out.drop('nlp_text', axis=1))
152
+ else:
153
+ st.dataframe(out)
154
+ s3.metric(label ='Accuracy',value = acc, delta = '', delta_color = 'inverse')