Spaces:
Runtime error
Runtime error
import os

# Set before importing anything that may load the Intel OpenMP runtime
# (e.g. numpy pulled in by streamlit/pandas/plotly, torch pulled in by
# transformers); the variable is presumably read when that runtime
# initializes, so setting it after those imports may be too late.
# NOTE(review): this only suppresses the "duplicate OpenMP library" abort;
# it does not remove the duplicate library itself.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from io import StringIO

import pandas as pd
import plotly.express as px
import streamlit as st
from annotated_text import annotated_text
from streamlit_option_menu import option_menu
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

# Streamlit requires set_page_config to be the first Streamlit command the
# script executes (at least in older releases), so it runs right after imports.
st.set_page_config(layout="wide")
def init_text_summarization_model():
    """Build the abstractive summarization pipeline.

    Returns:
        A Hugging Face ``pipeline("summarization")`` backed by the
        ``facebook/bart-large-cnn`` checkpoint.
    """
    return pipeline("summarization", model='facebook/bart-large-cnn')
def init_zsl_topic_classification():
    """Build the zero-shot topic classification pipeline.

    Returns:
        A ``(pipe, template)`` pair. ``pipe`` is a
        ``pipeline("zero-shot-classification")`` backed by
        ``facebook/bart-large-mnli``. ``template`` is the string
        ``"This text is about {}."``, presumably meant to be passed as the
        pipeline's hypothesis template.
    """
    MODEL = 'facebook/bart-large-mnli'
    pipe = pipeline("zero-shot-classification", model=MODEL)
    template = "This text is about {}."
    return pipe, template
def init_ner_pipeline():
    """Build the biomedical named-entity-recognition pipeline.

    Loads the ``d4data/biomedical-ner-all`` tokenizer and token-classification
    model and wraps them in a ``pipeline("ner")`` with
    ``aggregation_strategy="simple"``. Each result item therefore carries an
    ``entity_group`` label plus ``start``/``end`` character offsets, which the
    annotation helper in this module consumes.
    """
    checkpoint = "d4data/biomedical-ner-all"
    ner_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
    # No device is passed; add device=0 to run on the first GPU.
    return pipeline(
        "ner",
        model=ner_model,
        tokenizer=ner_tokenizer,
        aggregation_strategy="simple",
    )
def init_qa_pipeline():
    """Build the extractive question-answering pipeline.

    Returns:
        A ``pipeline("question-answering")`` using the
        ``distilbert-base-cased-distilled-squad`` checkpoint. Calling it with
        ``question=`` and ``context=`` yields a dict whose ``'answer'`` entry
        is displayed by the "Get Answers" page.
    """
    qa_checkpoint = 'distilbert-base-cased-distilled-squad'
    return pipeline("question-answering", model=qa_checkpoint)
def get_formatted_text_for_annotation(output, text=None):
    """Turn NER pipeline output into positional arguments for ``annotated_text``.

    The source text is reproduced in order. Stretches between entities are
    emitted as plain strings, and each entity span is emitted as a
    ``(span_text, entity_group, colour)`` tuple.

    Args:
        output: Entities from the aggregated NER pipeline. Each item is a
            mapping with ``'start'`` and ``'end'`` (character offsets into
            *text*) and ``'entity_group'`` (the label).
        text: The string the entities were extracted from. When omitted, the
            module-level ``text`` global is used, which keeps the original
            one-argument call working.

    Returns:
        A tuple of plain strings and ``(span_text, label, colour)`` tuples.
        When entities do not overlap, the concatenated text equals *text*.
        An empty *output* yields ``(text,)``; an empty *text* yields ``()``.
    """
    colour_map = {'Sex': '#5DD75D',
                  'Duration': '#D92E45',
                  'Sign_symptom': '#793F41',
                  'Frequency': '#232AE7',
                  'Detailed_description': '#E1D8D1',
                  'History': '#296FB8',
                  'Clinical_event': '#E840A7',
                  'Lab_value': '#FE90C3',
                  'Age': '#31404C',
                  'Biological_structure': '#1A4B5B',
                  'Diagnostic_procedure': '#804E7A'}
    # Labels missing from colour_map get a neutral grey instead of raising
    # KeyError. The NER model presumably emits more labels than are listed above.
    fallback_colour = '#8C8C8C'

    if text is None:
        # Backward compatibility: the original version read this global.
        text = globals()['text']

    annotated_texts = []
    cursor = 0  # index of the first character of `text` not yet emitted
    for entity in sorted(output, key=lambda ent: ent['start']):
        start, end = entity['start'], entity['end']
        if start < cursor:
            # Overlaps text already emitted; skip it rather than duplicate text.
            continue
        if start > cursor:
            # Keep the exact gap (including spaces and punctuation).
            annotated_texts.append(text[cursor:start])
        group = entity['entity_group']
        annotated_texts.append(
            (text[start:end], group, colour_map.get(group, fallback_colour))
        )
        cursor = end
    if cursor < len(text):
        annotated_texts.append(text[cursor:])
    return tuple(annotated_texts)
# Model initialization.
# Streamlit re-executes this whole script on every widget interaction. Each
# loader is therefore wrapped in st.cache_resource (Streamlit >= 1.18), so each
# model is loaded once per server process and then reused across reruns and
# sessions, instead of being reloaded on every click.
# NOTE(review): pipeline_zsl/template are not used by any menu option in this
# file; if nothing else needs them, dropping that line would avoid loading
# facebook/bart-large-mnli.
pipeline_summarization = st.cache_resource(init_text_summarization_model)()
pipeline_zsl, template = st.cache_resource(init_zsl_topic_classification)()
pipeline_ner = st.cache_resource(init_ner_pipeline)()
pipeline_qa = st.cache_resource(init_qa_pipeline)()
st.header("Intelligent Document Automation")
def get_text_from_ocr_engine(uploaded_file=None):
    """Return the text of the current document.

    This is a placeholder for a real OCR step. It always returns the same
    sample case report, regardless of its argument.

    Args:
        uploaded_file: The file from ``st.file_uploader``, if any. It is
            accepted so the "Upload Document" page can pass the upload, but
            it is currently ignored.

    Returns:
        The sample case-report text as a ``str``.
    """
    # TODO(review): run OCR / text extraction on uploaded_file instead of
    # returning the fixed sample below.
    return (
        "CASE: A 28-year-old previously healthy man presented with a 6-week "
        "history of palpitations. The symptoms occurred during rest, "
        "2–3 times per week, lasted up to 30 minutes at a time and were "
        "associated with dyspnea. Except for a grade 2/6 holosystolic "
        "tricuspid regurgitation murmur (best heard at the left sternal border "
        "with inspiratory accentuation), physical examination yielded "
        "unremarkable findings."
    )
# Sidebar navigation: the label chosen here selects which page body the
# if/elif chain that follows renders.
with st.sidebar:
    selected_menu = option_menu(
        "Select Option",
        [
            "Upload Document",
            "Extract Text",
            "Summarize Document",
            "Extract Entities",
            "Get Answers",
        ],
        menu_icon="cast",
        default_index=0,
    )
# Render the page for the option chosen in the sidebar menu.
if selected_menu == "Upload Document":
    uploaded_file = st.file_uploader("Choose a file")
    if uploaded_file is not None:
        # Called without arguments so this works with the zero-parameter
        # stub (passing uploaded_file raised TypeError).
        # NOTE(review): ocr_text is not yet shared with the other pages.
        ocr_text = get_text_from_ocr_engine()
        st.write("Upload Successful")
elif selected_menu == "Extract Text":
    st.write(get_text_from_ocr_engine())
elif selected_menu == "Summarize Document":
    text = get_text_from_ocr_engine()
    with st.spinner("Summarizing Document..."):
        # Summarize the document text itself (previously an undefined name).
        summary_text = pipeline_summarization(
            text, max_length=130, min_length=10, do_sample=False
        )
    # Show output
    st.write(summary_text[0]['summary_text'])
elif selected_menu == "Extract Entities":
    # Kept as the module-level name `text`: the annotation helper can fall
    # back to reading it as a global.
    text = get_text_from_ocr_engine()
    output = pipeline_ner(text)
    entities_text = get_formatted_text_for_annotation(output)
    annotated_text(*entities_text)
elif selected_menu == "Get Answers":
    st.subheader('Question')
    question_text = st.text_input("Type your question")
    context = get_text_from_ocr_engine()
    if question_text:
        # Use the QA pipeline built at start-up (previously an undefined name).
        result = pipeline_qa(question=question_text, context=context)
        st.subheader('Answer')
        st.text(result['answer'])