Sasidhar's picture
Update app.py
3ee5fd7
raw
history blame
4.99 kB
import streamlit as st
from annotated_text import annotated_text
from io import StringIO
from transformers import AutoTokenizer, AutoModelForTokenClassification
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import plotly.express as px
from streamlit_option_menu import option_menu
st. set_page_config(layout="wide")
from transformers import pipeline
import pandas as pd
@st.cache(allow_output_mutation=True)
def init_text_summarization_model():
    """Build the abstractive summarization pipeline.

    Wrapped in ``st.cache`` so the model is constructed once and reused
    across Streamlit reruns.

    Returns:
        A Hugging Face ``"summarization"`` pipeline backed by the
        ``facebook/bart-large-cnn`` checkpoint.
    """
    return pipeline("summarization", model='facebook/bart-large-cnn')
@st.cache(allow_output_mutation=True)
def init_zsl_topic_classification():
    """Build the zero-shot topic-classification pipeline and its prompt.

    Defined once. The file previously contained two identical copies, the
    second silently shadowing the first.

    Returns:
        A ``(pipe, template)`` tuple, where:

        - ``pipe`` is a ``"zero-shot-classification"`` pipeline using
          ``facebook/bart-large-mnli``.
        - ``template`` is the hypothesis string with a ``{}`` placeholder
          for the candidate label.
    """
    MODEL = 'facebook/bart-large-mnli'
    pipe = pipeline("zero-shot-classification", model=MODEL)
    template = "This text is about {}."
    return pipe, template
@st.cache(allow_output_mutation=True)
def init_ner_pipeline():
    """Build the biomedical named-entity-recognition pipeline.

    The tokenizer and the token-classification model are both loaded from
    the ``d4data/biomedical-ner-all`` checkpoint. ``aggregation_strategy``
    is ``"simple"``; callers read ``'start'``, ``'end'`` and
    ``'entity_group'`` from each result.

    Returns:
        A Hugging Face ``"ner"`` pipeline.
    """
    checkpoint = "d4data/biomedical-ner-all"
    ner_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
    # Pass device=0 to the pipeline call if running on a GPU.
    return pipeline(
        "ner",
        model=ner_model,
        tokenizer=ner_tokenizer,
        aggregation_strategy="simple",
    )
@st.cache(allow_output_mutation=True)
def init_qa_pipeline():
    """Build the extractive question-answering pipeline.

    Returns:
        A ``"question-answering"`` pipeline using the
        ``distilbert-base-cased-distilled-squad`` checkpoint.
    """
    qa_checkpoint = 'distilbert-base-cased-distilled-squad'
    return pipeline("question-answering", model=qa_checkpoint)
def get_formatted_text_for_annotation(output, text=None):
    """Convert NER pipeline output into positional args for ``annotated_text``.

    The source text is split, in order, into plain-string gaps and
    ``(entity_text, label, colour)`` tuples. Joining the string parts with
    the first element of each tuple reproduces *text* exactly: no characters
    are dropped around entities or at the end.

    Args:
        output: Entities from the NER pipeline. Each is a mapping with at
            least ``'start'`` and ``'end'`` (character offsets into *text*)
            and ``'entity_group'``. Assumed sorted by ``'start'`` and
            non-overlapping; TODO confirm if other models are used.
        text: The text the entities were extracted from. When omitted, the
            module-level ``text`` global is used, for backward compatibility
            with callers that pass only *output*.

    Returns:
        A tuple of ``str`` gaps and ``(str, str, str)`` annotation tuples,
        suitable for ``annotated_text(*result)``.
    """
    colour_map = {'Sex': '#5DD75D',
                  'Duration': '#D92E45',
                  'Sign_symptom': '#793F41',
                  'Frequency': '#232AE7',
                  'Detailed_description': '#E1D8D1',
                  'History': '#296FB8',
                  'Clinical_event': '#E840A7',
                  'Lab_value': '#FE90C3',
                  'Age': '#31404C',
                  'Biological_structure': '#1A4B5B',
                  'Diagnostic_procedure': '#804E7A'}
    # Labels not listed above get a neutral colour instead of raising KeyError.
    default_colour = '#A0A0A0'

    if text is None:
        text = globals()['text']

    annotated_texts = []
    next_index = 0
    for entity in output:
        start, end = entity['start'], entity['end']
        if start > next_index:
            # Keep the unannotated gap verbatim, including spaces and punctuation.
            annotated_texts.append(text[next_index:start])
        label = entity['entity_group']
        annotated_texts.append(
            (text[start:end], label, colour_map.get(label, default_colour)))
        # max() guards against an overlapping entity moving the cursor backwards.
        next_index = max(next_index, end)
    if next_index < len(text):
        annotated_texts.append(text[next_index:])
    return tuple(annotated_texts)
# Model initialization
# Each init_* helper is decorated with @st.cache, so the pipelines are built
# once and reused on later Streamlit reruns instead of being reloaded.
pipeline_summarization = init_text_summarization_model()
# NOTE(review): pipeline_zsl / template do not appear to be used by any menu
# option in this file — confirm before removing.
pipeline_zsl, template = init_zsl_topic_classification()
pipeline_ner =init_ner_pipeline()
pipeline_qa = init_qa_pipeline()
# Page title, rendered above whichever menu option is selected.
st.header("Intelligent Document Automation")
def get_text_from_ocr_engine(uploaded_file=None):
    """Return the text of the current document.

    This is a placeholder for a real OCR engine. It ignores its input and
    always returns the same hard-coded clinical case description.

    Args:
        uploaded_file: The file returned by ``st.file_uploader``. It is
            currently unused. It is accepted so callers can pass the upload
            without raising ``TypeError``.

    Returns:
        The document text as a ``str``.
    """
    # TODO: run real OCR on ``uploaded_file`` instead of returning a fixture.
    return "CASE: A 28-year-old previously healthy man presented with a 6-week history of palpitations. The symptoms occurred during rest, 2–3 times per week, lasted up to 30 minutes at a time and were associated with dyspnea. Except for a grade 2/6 holosystolic tricuspid regurgitation murmur (best heard at the left sternal border with inspiratory accentuation), physical examination yielded unremarkable findings."
with st.sidebar:
    selected_menu = option_menu("Select Option",
                                ["Upload Document", "Extract Text", "Summarize Document", "Extract Entities", "Get Answers"],
                                menu_icon="cast", default_index=0)

if selected_menu == "Upload Document":
    uploaded_file = st.file_uploader("Choose a file")
    if uploaded_file is not None:
        # get_text_from_ocr_engine() is a stub that returns fixed text, so it
        # is called without arguments.
        ocr_text = get_text_from_ocr_engine()
        st.write("Upload Successful")
elif selected_menu == "Extract Text":
    st.write(get_text_from_ocr_engine())
elif selected_menu == "Summarize Document":
    text = get_text_from_ocr_engine()
    with st.spinner("Summarizing Document..."):
        # Summarize the document text extracted just above.
        summary_text = pipeline_summarization(text, max_length=130, min_length=10, do_sample=False)
    # Show output
    st.write(summary_text[0]['summary_text'])
elif selected_menu == "Extract Entities":
    text = get_text_from_ocr_engine()
    output = pipeline_ner(text)
    # The formatter slices the module-level ``text`` bound just above.
    entities_text = get_formatted_text_for_annotation(output)
    annotated_text(*entities_text)
elif selected_menu == "Get Answers":
    st.subheader('Question')
    question_text = st.text_input("Type your question")
    context = get_text_from_ocr_engine()
    if question_text:
        # Use the QA pipeline initialised at module load.
        result = pipeline_qa(question=question_text, context=context)
        st.subheader('Answer')
        st.text(result['answer'])