# ner-gradio / app.py
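# Gradio demo that runs legal named-entity recognition over Indian Kanoon judgments.
# Two InLegalBERT token-classification models are used: one for the case preamble
# (parties, judges, lawyers, court) and one for the judgment body (statutes,
# provisions, precedents, dates, etc.); results are rendered with spaCy's displacy.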
import gradio as gr
import http.client
import json
import re

import torch
import spacy
from spacy import displacy
from bs4 import BeautifulSoup as bs
from transformers import AutoTokenizer, BertForTokenClassification
# import pymongo
# from pymongo import MongoClient  # only needed for the disabled MongoDB repository logic below

# WordPiece tokenizer shared by both taggers.
tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")
class BertModel(torch.nn.Module):
    """Token-classification head on top of InLegalBERT."""

    def __init__(self):
        super(BertModel, self).__init__()
        self.bert = BertForTokenClassification.from_pretrained('law-ai/InLegalBERT', num_labels=14)

    def forward(self, input_id, mask, label):
        # With return_dict=False the underlying model returns a tuple; when label is
        # None its first element is the logits tensor.
        output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)
        return output
# The checkpoints store the full model objects, so torch.load returns ready-to-use
# models (the BertModel class above is still needed to unpickle them).
model_preamble = torch.load("nerbert_preamble.pt", map_location=torch.device('cpu'))
model_preamble.eval()
model_judgment = torch.load("nerbert.pt", map_location=torch.device('cpu'))
model_judgment.eval()
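# Entity label inventories for the two taggers: the preamble model covers parties,
# judges, lawyers and the court; the judgment model additionally covers statutes,
# provisions, precedents, case numbers, dates, organisations, locations and witnesses.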
unique_labels_preamble = {'I-PETITIONER', 'I-COURT', 'B-COURT', 'B-JUDGE', 'I-LAWYER', 'B-RESPONDENT', 'I-JUDGE', 'B-PETITIONER', 'I-RESPONDENT', 'B-LAWYER', 'O'}
unique_labels_judgment = {'B-WITNESS', 'I-PETITIONER', 'I-JUDGE', 'B-STATUTE', 'B-OTHER_PERSON', 'B-CASE_NUMBER', 'I-ORG', 'I-PRECEDENT', 'I-RESPONDENT', 'B-PROVISION', 'O', 'I-WITNESS', 'B-ORG', 'I-COURT', 'B-RESPONDENT', 'I-DATE', 'B-GPE', 'I-CASE_NUMBER', 'B-DATE', 'B-PRECEDENT', 'I-GPE', 'B-COURT', 'B-JUDGE', 'I-STATUTE', 'B-PETITIONER', 'I-OTHER_PERSON', 'I-PROVISION'}
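# Label <-> id lookup tables, built from the sorted label sets.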
labels_to_ids_preamble = {k: v for v, k in enumerate(sorted(unique_labels_preamble))}
ids_to_labels_preamble = {v: k for v, k in enumerate(sorted(unique_labels_preamble))}
labels_to_ids_judgment = {k: v for v, k in enumerate(sorted(unique_labels_judgment))}
ids_to_labels_judgment = {v: k for v, k in enumerate(sorted(unique_labels_judgment))}
label_all_tokens = True
def align_word_ids(texts):
    # Build a mask over the tokenized input: -100 for special/padding tokens,
    # 1 for real word tokens, so the evaluation code can drop the former.
    tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)
    word_ids = tokenized_inputs.word_ids()
    previous_word_idx = None
    label_ids = []
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            label_ids.append(1)
        else:
            label_ids.append(1 if label_all_tokens else -100)
        previous_word_idx = word_idx
    return label_ids
def evaluate_one_preamble(model, sentence):
    # Tag one preamble text and map the predicted ids back to preamble label strings.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)
    text = tokenizer(sentence, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_id, mask, None)
    # Keep only the predictions for real (non-special, non-padding) tokens.
    logits_clean = logits[0][label_ids != -100]
    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels_preamble[i] for i in predictions]
    return (prediction_label, text)
def evaluate_one_text(model, sentence):
    # Tag one judgment sentence and map the predicted ids back to judgment label strings.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)
    text = tokenizer(sentence, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_id, mask, None)
    logits_clean = logits[0][label_ids != -100]
    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels_judgment[i] for i in predictions]
    return (prediction_label, text)
def cleanhtml(raw_html):
    # Strip any remaining HTML tags from the extracted text.
    CLEANR = re.compile('<.*?>')
    cleantext = re.sub(CLEANR, '', raw_html)
    return cleantext
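# Blank English pipeline: used only to build a Doc so displacy can render the predicted spans.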
nlp = spacy.blank("en")
def judgtext_analysis(text):
    # Fetch the judgment document from the Indian Kanoon API.
    conn = http.client.HTTPSConnection("api.indiankanoon.org")
    payload = "{}"
    headers = {
        'Authorization': 'Token ea381f5b51f9d55aaa71dfe6a90606e9b89f942a',
        'Content-Type': 'application/json'
    }
    # Parse the URL and retrieve the document id; the input is expected to look
    # like https://indiankanoon.org/doc/<docid>/ so the id is the fifth segment.
    d = text.split('/')
    docid = d[4]
    endpoint = "/doc/" + str(docid) + "/"
    conn.request("POST", endpoint, payload, headers)
    res = conn.getresponse()
    data = res.read().decode("utf-8")
    data_dict = json.loads(data)
    # The API returns the judgment as HTML: paragraphs and blockquotes hold the
    # judgment body, the <pre> block holds the preamble.
    soup = bs(data_dict["doc"], 'html.parser')
    judgment_text = ""
    for tag in soup.find_all(['p', 'blockquote']):
        judgment_text += tag.text + " "
    judgment_text = cleanhtml(str(judgment_text))
    preamble_text = soup.find("pre")
    preamble_text = cleanhtml(str(preamble_text))
    # Naive sentence split on ., ? or ! optionally followed by closing quotes/brackets.
    judgment_sentences = re.split(r' *[\.\?!][\'"\)\]]* *', judgment_text)
    finalentities = []
    finaltext = ""
    # --- Preamble: run the preamble tagger over the whole <pre> block ---
    labellist, text_tokenized = evaluate_one_preamble(model_preamble, preamble_text)
    tokenlist = tokenizer.convert_ids_to_tokens(text_tokenized["input_ids"][0])
    finallist = []
    for i in range(1, len(tokenlist)):
        if tokenlist[i] == '[SEP]':
            break
        finallist.append(tokenlist[i])
    finalstring = ""
    i = 0
    finallistshortened = []
    labellistshortened = []
    # Merge "##" continuation pieces back into whole words; keep one label per word.
    while i < len(finallist):
        word = ""
        word += finallist[i]
        j = i + 1
        labellistshortened.append(labellist[i])
        while j < len(finallist) and finallist[j].startswith("##"):
            word += finallist[j][2:]
            j += 1
        finalstring += word
        finallistshortened.append(word)
        finalstring += " "
        i = j
    # Group consecutive B-/I- labels into entity spans with character offsets
    # relative to the rebuilt text.
    text = ""
    entities = []
    i = 0
    while i < len(finallistshortened):
        word = ""
        start = len(text)
        word += finallistshortened[i] + " "
        j = i + 1
        if labellistshortened[i] == "O":
            i += 1
            text += word + " "
            continue
        entity = labellistshortened[i][2:]
        ientity = "I-" + entity
        while j < len(finallistshortened) and labellistshortened[j] == ientity:
            word += finallistshortened[j] + " "
            j += 1
        text += word + " "
        prevstart = len(finaltext)
        end = len(text) - 2
        finalstring += text + ". "
        entities.append((entity, start, end))
        finalentities.append((entity, prevstart + start, prevstart + end))
        i = j
    finaltext += text + ". "
    # --- Judgment body: run the judgment tagger sentence by sentence ---
    for sentence in judgment_sentences:
        labellist, text_tokenized = evaluate_one_text(model_judgment, sentence)
        tokenlist = tokenizer.convert_ids_to_tokens(text_tokenized["input_ids"][0])
        finallist = []
        for i in range(1, len(tokenlist)):
            if tokenlist[i] == '[SEP]':
                break
            finallist.append(tokenlist[i])
        finalstring = ""
        i = 0
        finallistshortened = []
        labellistshortened = []
        # Merge "##" continuation pieces back into whole words; keep one label per word.
        while i < len(finallist):
            word = ""
            word += finallist[i]
            j = i + 1
            labellistshortened.append(labellist[i])
            while j < len(finallist) and finallist[j].startswith("##"):
                word += finallist[j][2:]
                j += 1
            finalstring += word
            finallistshortened.append(word)
            finalstring += " "
            i = j
        # Group consecutive B-/I- labels into entity spans, offsets relative to finaltext.
        text = ""
        entities = []
        i = 0
        while i < len(finallistshortened):
            word = ""
            start = len(text)
            word += finallistshortened[i] + " "
            j = i + 1
            if labellistshortened[i] == "O":
                i += 1
                text += word + " "
                continue
            entity = labellistshortened[i][2:]
            ientity = "I-" + entity
            while j < len(finallistshortened) and labellistshortened[j] == ientity:
                word += finallistshortened[j] + " "
                j += 1
            text += word + " "
            prevstart = len(finaltext)
            end = len(text) - 2
            finalstring += text + ". "
            entities.append((entity, start, end))
            finalentities.append((entity, prevstart + start, prevstart + end))
            i = j
        finaltext += text + ". "
    # Build a spaCy Doc over the reconstructed text and attach the predicted spans.
    doc = nlp(finaltext)
    ents = []
    for ee in finalentities:
        span = doc.char_span(ee[1], ee[2], ee[0])
        if span is not None:  # skip spans whose character offsets do not align with tokens
            ents.append(span)
    doc.ents = ents
    # Logic for the (currently disabled) MongoDB repository of extracted entities.
    # cluster = MongoClient("mongodb+srv://testuser:test123@ner-gradio.mgng1wv.mongodb.net/?retryWrites=true&w=majority")
    # db = cluster["nerdb"]
    # collection = db["named_entities"]
    # content = displacy.render(doc, style='ent')  # only needed to feed the parsing below
    # extsoup = bs(content, 'html.parser')
    # txtlist = []
    # entlist = []
    # for h in extsoup.findAll('div'):
    #     mark = h.findAll('mark')
    #     span = h.findAll('span')
    #     for i in mark:
    #         txt = i.find(text=True)
    #         txt = txt.replace("\n", "")
    #         txt = txt.strip()
    #         txtlist.append(str(txt))
    #     for i in span:
    #         ent = i.find(text=True)
    #         entlist.append(str(ent))
    # zipped = zip(txtlist, entlist)
    # ziplist = list(zipped)
    # post = {"api_docid": docid, "document_text": finaltext, "named_entities": ziplist, "entities_indices": finalentities}
    # collection.insert_one(post)
    # Render the annotated document as a standalone HTML page for Gradio.
    html = displacy.render(doc, style="ent", page=True)
    return html
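# Gradio UI: a single textbox for an Indian Kanoon document URL; the handler
# returns the displacy-rendered HTML.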
demo = gr.Interface(
    judgtext_analysis,
    gr.Textbox(placeholder="Enter Indian Kanoon document URL here..."),
    ["html"],
)
demo.launch(inline=False)
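# Example input shape (hypothetical id): https://indiankanoon.org/doc/<docid>/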