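"""Intelligent Document Automation -- Streamlit demo.

A single-page app that takes a document (OCR is stubbed with a sample clinical note)
and offers text extraction, summarization, biomedical named-entity highlighting and
extractive question answering via Hugging Face pipelines.
"""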
import streamlit as st
from annotated_text import annotated_text
from io import StringIO
from transformers import AutoTokenizer, AutoModelForTokenClassification
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import plotly.express as px
from streamlit_option_menu import option_menu
st.set_page_config(layout="wide")
from transformers import pipeline
import pandas as pd
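# Each loader below is wrapped in st.cache(allow_output_mutation=True) so the transformer
# models are downloaded and initialised only once and reused across Streamlit reruns.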
@st.cache(allow_output_mutation=True)
def init_text_summarization_model():
    MODEL = 'facebook/bart-large-cnn'
    pipe = pipeline("summarization", model=MODEL)
    return pipe

@st.cache(allow_output_mutation=True)
def init_zsl_topic_classification():
    MODEL = 'facebook/bart-large-mnli'
    pipe = pipeline("zero-shot-classification", model=MODEL)
    template = "This text is about {}."
    return pipe, template
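# Illustrative use of the zero-shot pipeline (not wired into the menu below), with
# caller-chosen candidate labels:
#   pipe, template = init_zsl_topic_classification()
#   pipe("The patient reported palpitations.", candidate_labels=["cardiology", "dermatology"],
#        hypothesis_template=template)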
@st.cache(allow_output_mutation=True)
def init_ner_pipeline():
    tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
    model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
    pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")  # pass device=0 if using gpu
    return pipe

@st.cache(allow_output_mutation=True)
def init_qa_pipeline():
    question_answerer_pipe = pipeline("question-answering", model='distilbert-base-cased-distilled-squad')
    return question_answerer_pipe
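# For reference, pipeline_qa(question=..., context=...) returns a dict with 'score',
# 'start', 'end' and 'answer'; the "Get Answers" branch below reads result['answer'].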
def get_formatted_text_for_annotation(output, text):
    """Convert NER pipeline output into the strings and (text, label, colour) tuples expected by annotated_text."""
    colour_map = {'Sex': '#5DD75D',
                  'Duration': '#D92E45',
                  'Sign_symptom': '#793F41',
                  'Frequency': '#232AE7',
                  'Detailed_description': '#E1D8D1',
                  'History': '#296FB8',
                  'Clinical_event': '#E840A7',
                  'Lab_value': '#FE90C3',
                  'Age': '#31404C',
                  'Biological_structure': '#1A4B5B',
                  'Diagnostic_procedure': '#804E7A'}
    annotated_texts = []
    next_index = 0
    for entity in output:
        if entity['start'] > next_index:
            # Plain text between the previous entity and this one
            annotated_texts.append(text[next_index:entity['start']])
        extracted_text = text[entity['start']:entity['end']]
        # Fall back to a neutral grey for entity groups not listed in colour_map
        colour = colour_map.get(entity['entity_group'], '#8A8A8A')
        annotated_texts.append((extracted_text, entity['entity_group'], colour))
        next_index = entity['end']
    if next_index < len(text):
        # Trailing plain text after the last entity
        annotated_texts.append(text[next_index:])
    return tuple(annotated_texts)
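# annotated_text() accepts a mix of plain strings and (text, label, colour) tuples, e.g.
# (illustrative): annotated_text("A ", ("28-year-old", "Age", "#31404C"), " man presented ...")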
# Model initialization
pipeline_summarization = init_text_summarization_model()
pipeline_zsl, template = init_zsl_topic_classification()
pipeline_ner = init_ner_pipeline()
pipeline_qa = init_qa_pipeline()
st.header("Intelligent Document Automation")
def get_text_from_ocr_engine(uploaded_file=None):
    # Placeholder OCR: the uploaded file is ignored and a fixed sample clinical note is returned.
    return "CASE: A 28-year-old previously healthy man presented with a 6-week history of palpitations. The symptoms occurred during rest, 2–3 times per week, lasted up to 30 minutes at a time and were associated with dyspnea. Except for a grade 2/6 holosystolic tricuspid regurgitation murmur (best heard at the left sternal border with inspiratory accentuation), physical examination yielded unremarkable findings."
with st.sidebar:
    selected_menu = option_menu("Select Option",
                                ["Upload Document", "Extract Text", "Summarize Document", "Extract Entities", "Get Answers"],
                                menu_icon="cast", default_index=0)

if selected_menu == "Upload Document":
    uploaded_file = st.file_uploader("Choose a file")
    if uploaded_file is not None:
        ocr_text = get_text_from_ocr_engine(uploaded_file)
        st.write("Upload Successful")
elif selected_menu == "Extract Text":
    st.write(get_text_from_ocr_engine())
elif selected_menu == "Summarize Document":
text = get_text_from_ocr_engine()
with st.spinner("Summarizing Document..."):
summary_text = pipeline_summarization(review, max_length=130, min_length=10, do_sample=False)
# Show output
st.write(summary_text[0]['summary_text'])
elif selected_menu == "Extract Entities":
text = get_text_from_ocr_engine()
output = pipeline_ner (text)
entities_text =get_formatted_text_for_annotation(output)
annotated_text(*entities_text)
elif selected_menu == "Get Answers":
st.subheader('Question')
question_text = st.text_input("Type your question")
context = get_text_from_ocr_engine()
if question_text:
result = question_answerer(question=question_text, context=context)
st.subheader('Answer')
st.text(result['answer']) |
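# To run locally (assuming this file is the Space's entry point, e.g. app.py), install the
# dependencies (streamlit, transformers, torch, pandas, plotly, streamlit-option-menu,
# st-annotated-text) and start the app with:
#   streamlit run app.py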