import streamlit as st
from annotated_text import annotated_text
from io import StringIO
from transformers import AutoTokenizer, AutoModelForTokenClassification
import os
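# KMP_DUPLICATE_LIB_OK is a common workaround for "duplicate OpenMP runtime"
# crashes (libomp/libiomp) when PyTorch and other native libraries share the
# same environment; it suppresses the error rather than fixing the underlying conflict.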
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import plotly.express as px
from streamlit_option_menu import option_menu
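# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py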

st.set_page_config(layout="wide")

from transformers import pipeline
import pandas as pd

@st.cache(allow_output_mutation = True)
def init_text_summarization_model():
    MODEL = 'facebook/bart-large-cnn'
    pipe = pipeline("summarization", model=MODEL)
    return pipe
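# Usage sketch (output format of the transformers summarization pipeline):
#   pipe(text, max_length=130, min_length=10, do_sample=False)
#   -> [{'summary_text': '...'}]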

@st.cache(allow_output_mutation = True)
def init_zsl_topic_classification():
    MODEL = 'facebook/bart-large-mnli'
    pipe = pipeline("zero-shot-classification", model=MODEL)
    template = "This text is about {}."
    return pipe, template
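# Note: pipeline_zsl/template are loaded below but not wired to a menu option in
# this snippet; a call would look roughly like
#   pipe(text, candidate_labels=["cardiology", "..."], hypothesis_template=template)
# and returns a dict containing 'labels' and 'scores'.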

@st.cache(allow_output_mutation = True)
def init_ner_pipeline():
    tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
    model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
    pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple") # pass device=0 if using gpu
    return pipe
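# With aggregation_strategy="simple", the NER pipeline merges sub-word tokens and
# returns one dict per entity span: {'entity_group', 'score', 'word', 'start', 'end'},
# where 'start'/'end' are character offsets into the input text. Those offsets are
# what get_formatted_text_for_annotation() relies on below.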

@st.cache(allow_output_mutation = True)
def init_qa_pipeline():
    question_answerer_pipe = pipeline("question-answering", model='distilbert-base-cased-distilled-squad')
    return question_answerer_pipe
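# The question-answering pipeline returns a dict such as
#   {'score': 0.98, 'start': 8, 'end': 19, 'answer': '28-year-old'}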

def get_formatted_text_for_annotation(text, output):
    """Convert NER pipeline output into the mix of plain strings and
    (text, label, colour) tuples expected by annotated_text()."""
    colour_map = {'Sex': '#5DD75D',
                  'Duration': '#D92E45',
                  'Sign_symptom': '#793F41',
                  'Frequency': '#232AE7',
                  'Detailed_description': '#E1D8D1',
                  'History': '#296FB8',
                  'Clinical_event': '#E840A7',
                  'Lab_value': '#FE90C3',
                  'Age': '#31404C',
                  'Biological_structure': '#1A4B5B',
                  'Diagnostic_procedure': '#804E7A'}

    annotated_texts = []
    next_index = 0
    for entity in output:
        # Emit any plain text between the previous entity and this one.
        if entity['start'] > next_index:
            annotated_texts.append(text[next_index:entity['start']])
        extracted_text = text[entity['start']:entity['end']]
        # Fall back to a neutral grey for entity groups not in the colour map.
        colour = colour_map.get(entity['entity_group'], '#AFAFAF')
        annotated_texts.append((extracted_text, entity['entity_group'], colour))
        next_index = entity['end']

    # Emit any trailing plain text after the last entity.
    if next_index < len(text):
        annotated_texts.append(text[next_index:])

    return tuple(annotated_texts)
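# annotated_text() accepts a mix of plain strings (rendered as-is) and
# (text, label, background_colour) tuples (rendered as highlighted chips),
# which is exactly the structure built above.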
    
# Model initialization    
pipeline_summarization = init_text_summarization_model()
pipeline_zsl, template = init_zsl_topic_classification()
pipeline_ner = init_ner_pipeline()
pipeline_qa = init_qa_pipeline()
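# All four models are downloaded from the Hugging Face Hub on first use; st.cache
# keeps the constructed pipelines in memory so they are not rebuilt on every rerun.
# On newer Streamlit versions, st.cache has been replaced by st.cache_resource,
# so these decorators may need updating.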

st.header("Intelligent Document Automation")

def get_text_from_ocr_engine(uploaded_file=None):
    # Placeholder for a real OCR step: the uploaded file is ignored and a
    # hard-coded sample case note is returned instead.
    return "CASE: A 28-year-old previously healthy man presented with a 6-week history of palpitations. The symptoms occurred during rest, 2–3 times per week, lasted up to 30 minutes at a time and were associated with dyspnea. Except for a grade 2/6 holosystolic tricuspid regurgitation murmur (best heard at the left sternal border with inspiratory accentuation), physical examination yielded unremarkable findings."

with st.sidebar:
    selected_menu = option_menu("Select Option", 
    ["Upload Document", "Extract Text", "Summarize Document", "Extract Entities","Get Answers"], 
        menu_icon="cast", default_index=0)
    

if selected_menu == "Upload Document":
    uploaded_file = st.file_uploader("Choose a file")        
    if uploaded_file is not None:
        ocr_text = get_text_from_ocr_engine(uploaded_file)
        st.write("Upload Successful")
elif selected_menu == "Extract Text":
    st.write(get_text_from_ocr_engine())
elif selected_menu == "Summarize Document":
    text = get_text_from_ocr_engine()
    with st.spinner("Summarizing Document..."):
        summary_text = pipeline_summarization(review, max_length=130, min_length=10, do_sample=False)
        # Show output
        st.write(summary_text[0]['summary_text'])
        
elif selected_menu == "Extract Entities":
    text = get_text_from_ocr_engine()
    output = pipeline_ner (text)
    entities_text =get_formatted_text_for_annotation(output)
    annotated_text(*entities_text)

elif selected_menu == "Get Answers":
    st.subheader('Question')
    question_text = st.text_input("Type your question")
    context = get_text_from_ocr_engine()
    if question_text:
        result = pipeline_qa(question=question_text, context=context)
        st.subheader('Answer')
        st.text(result['answer'])