import torch
import streamlit as st
from streamlit import components
import pandas as pd
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import T5Tokenizer, T5ForConditionalGeneration
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, LongT5ForConditionalGeneration
import numpy as np
from math import ceil
import en_core_web_lg
from collections import Counter
from string import punctuation
# Gensim
import gensim
from gensim.summarization import summarize
import spacy

# spaCy English pipeline, used for sentence splitting elsewhere in the app.
nlp = en_core_web_lg.load()

st.set_page_config(page_title='Clinical Note Summarization',
                   layout='wide')
st.title('Clinical Note Summarization')
st.sidebar.markdown('Using transformer model')

## Loading in dataset
df = pd.read_csv("demo_shpi_w_rouge25Nov.csv")
# HADM_ID presumably arrives as float (e.g. "123.0" after astype(str));
# strip only a trailing ".0" so the IDs read as plain integers.
df['HADM_ID'] = df['HADM_ID'].astype(str).str.replace(r'\.0$', '', regex=True)

# Rename columns to the names used throughout the app.
df.rename(columns={'SUBJECT_ID': 'Patient_ID',
                   'HADM_ID': 'Admission_ID',
                   'hpi_input_text': 'Original_Text',
                   'hpi_reference_summary': 'Reference_text'}, inplace=True)

# Filter selection
st.sidebar.header("Search for Patient:")
# A patient can have several admissions (one row each), so offer each
# patient ID only once.
patientid = df['Patient_ID'].unique()
patient = st.sidebar.selectbox('Select Patient ID:', patientid)
# Admissions belonging to the selected patient.
admissionid = df.loc[df['Patient_ID'] == patient, 'Admission_ID'].unique()
HospitalAdmission = st.sidebar.selectbox('Select Hospital Admission ID:', admissionid)
# List of models available
model = st.sidebar.selectbox(
    'Select Model',
    ('BART', 'BERT', 'BertGPT2', 'Gensim', 'LexRank', 'Long T5', 'Luhn',
     'Pysummarization', 'SBERT Summary Tokenizer', 'T5', 'T5 Seq2Seq',
     'T5-Base', 'TextRank'),
)

# Default beam-search settings for the BART / T5 generate() calls; the
# length limits can be overridden by the user below.
_num_beams = 4
_no_repeat_ngram_size = 3
_early_stopping = True
if model == 'BART':
    _length_penalty = 1
    _min_length = 12
    _max_length = 128
else:
    _length_penalty = 2
    _min_length = 30
    _max_length = 200

col3, col4 = st.columns(2)
col3.write(f"Patient ID: {patient} ")
col4.write(f"Admission ID: {HospitalAdmission} ")

col1, col2 = st.columns(2)
_min_length = col1.number_input("Minimum Length", value=_min_length)
_max_length = col2.number_input("Maximum Length", value=_max_length)

# Query out the clinical note(s) for the selected patient / admission.
original_text = df.query(
    "Patient_ID == @patient & Admission_ID == @HospitalAdmission"
)
original_text2 = original_text['Original_Text'].values
reference_text = original_text['Reference_text'].values


def _join_texts(values):
    """Join an array of note strings into one plain-text string.

    Missing values are skipped. (Calling str() on the numpy array instead
    yields its repr -- brackets, quotes and escaped newlines -- which would
    then be fed to the summarizers.)
    """
    return '\n\n'.join(str(value) for value in values if pd.notna(value))


runtext = st.text_area('Input Clinical Note here:', _join_texts(original_text2),
                       height=300)


def visualize(sentence_list, best_sentences):
    """Return HTML for *sentence_list* with the *best_sentences* highlighted.

    Args:
        sentence_list: iterable of sentence strings; a plain string is
            treated as a single sentence (not iterated character by character).
        best_sentences: iterable of strings to highlight. Matching ignores
            differences in whitespace.

    Returns:
        An HTML string of the sentences joined by spaces, with matching
        sentences wrapped in <mark>. All text is HTML-escaped because the
        result is rendered with ``unsafe_allow_html=True``.
    """
    import html

    if isinstance(sentence_list, str):
        sentence_list = [sentence_list]
    wanted = {' '.join(str(sentence).split()) for sentence in best_sentences}
    pieces = []
    for sentence in sentence_list:
        sentence = str(sentence)
        escaped = html.escape(sentence)
        if ' '.join(sentence.split()) in wanted:
            pieces.append(f"<mark>{escaped}</mark>")
        else:
            pieces.append(escaped)
    return ' '.join(pieces)


# ===== Pysummarization =====
from pysummarization.nlpbase.auto_abstractor import AutoAbstractor
from pysummarization.tokenizabledoc.simple_tokenizer import SimpleTokenizer
from pysummarization.abstractabledoc.top_n_rank_abstractor import TopNRankAbstractor
import regex as re

auto_abstractor = AutoAbstractor()
auto_abstractor.tokenizable_doc = SimpleTokenizer()
auto_abstractor.delimiter_list = [".", "\n"]
abstractable_doc = TopNRankAbstractor()


def pysummarizer(input_text):
    """Extractive summary of *input_text* using pysummarization's TopNRankAbstractor.

    Each selected sentence has its internal whitespace collapsed, and the
    sentences are joined with single spaces (they were previously glued
    together with no separator).
    """
    summary = auto_abstractor.summarize(input_text, abstractable_doc)
    best_sentences = [re.sub(r'\s+', ' ', sentence).strip()
                      for sentence in summary['summarize_result']]
    return ' '.join(sentence for sentence in best_sentences if sentence)


# ===== BERT Summary tokenizer =====
def BertSummarizer(input_text):
    """Extractive summary (ratio 0.4) using bert-extractive-summarizer's default model."""
    from summarizer import Summarizer

    bert_model = Summarizer()  # load once; the old code built two instances
    return bert_model(input_text, ratio=0.4)


# ===== SBERT =====
from summarizer.sbert import SBertSummarizer

Sbertmodel = SBertSummarizer('paraphrase-MiniLM-L6-v2')


def Sbert(input_text):
    """Extractive summary (ratio 0.4) using the module-level SBERT summarizer."""
    return Sbertmodel(input_text, ratio=0.4)


# ===== T5 Seq2Seq =====
def t5seq2seq(input_text):
    """Abstractive summary with t5-base; input truncated to 512 tokens."""
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    t5_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    inputs = tokenizer("summarize: " + input_text, return_tensors="pt",
                       max_length=512, truncation=True)
    outputs = t5_model.generate(inputs["input_ids"], max_length=150,
                                min_length=40, length_penalty=2.0,
                                num_beams=4, early_stopping=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# ===== BioClinicalBERT encoder + GPT2 decoder (experimental) =====
def BertGPT2(input_text):
    """Abstractive summary with the bert2gpt2 encoder-decoder checkpoint.

    The text is encoded with the Bio_ClinicalBERT tokenizer (truncated to 512
    tokens) and decoded with the GPT2 tokenizer. Experimental: output quality
    is still being investigated.

    Returns:
        The decoded summary string (previously the function returned None).
    """
    from transformers import GPT2Tokenizer, EncoderDecoderModel
    from transformers import AutoTokenizer

    run_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    enc_dec_model = EncoderDecoderModel.from_pretrained(
        "patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
    enc_dec_model.to(run_device)

    bert_tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    # CLS token will work as BOS token, SEP token as EOS token.
    bert_tokenizer.bos_token = bert_tokenizer.cls_token
    bert_tokenizer.eos_token = bert_tokenizer.sep_token

    # Make GPT2 add BOS/EOS around sequences. NOTE: this patches the
    # GPT2Tokenizer class itself, so it affects every GPT2Tokenizer in the process.
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]

    GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # set pad_token to unk_token -> careful: unk_token_id == eos_token_id == bos_token_id
    gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token

    enc_dec_model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id
    enc_dec_model.config.eos_token_id = gpt2_tokenizer.eos_token_id

    inputs = bert_tokenizer([str(input_text)], padding="max_length",
                            truncation=True, max_length=512, return_tensors="pt")
    # Decoding parameters are passed to generate() directly; assigning them
    # as attributes on the model object (as before) had no effect.
    outputs = enc_dec_model.generate(
        inputs.input_ids.to(run_device),
        attention_mask=inputs.attention_mask.to(run_device),
        decoder_start_token_id=gpt2_tokenizer.bos_token_id,
        eos_token_id=gpt2_tokenizer.eos_token_id,
        max_length=142,
        min_length=56,
        no_repeat_ngram_size=3,
        early_stopping=True,
        length_penalty=2.0,
        num_beams=4,
    )
    # All special tokens are removed by skip_special_tokens.
    return gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _show_summary(summary):
    """Render *summary* under a 'Summary' heading (the layout every model uses)."""
    st.write('Summary')
    st.success(summary)


def _generation_kwargs():
    """Beam-search settings shared by the BART and T5 branches (read from module globals)."""
    return dict(num_beams=_num_beams,
                no_repeat_ngram_size=_no_repeat_ngram_size,
                length_penalty=_length_penalty,
                min_length=_min_length,
                max_length=_max_length,
                early_stopping=_early_stopping)


def _bart_summary(input_text):
    """Abstractive summary with facebook/bart-base on the module-level device."""
    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(device)
    bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    input_text = ' '.join(str(input_text).split())
    # truncation=True caps the input at the tokenizer's model_max_length;
    # long notes would otherwise exceed BART's position embeddings.
    input_tokenized = bart_tokenizer.encode(input_text, return_tensors='pt',
                                            truncation=True).to(device)
    summary_ids = bart_model.generate(input_tokenized, **_generation_kwargs())
    return bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True,
                                 clean_up_tokenization_spaces=False)


def _t5_summary(input_text):
    """Abstractive summary with t5-base on the module-level device."""
    t5_model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
    # Collapse all whitespace (newlines included) to single spaces; the old
    # replace('\n', '') merged the words on either side of a line break.
    input_text = ' '.join(str(input_text).split())
    input_tokenized = t5_tokenizer.encode(input_text, return_tensors="pt").to(device)
    # Presumably the t5-base token ids for the task prefix "summarize:".
    summary_task = torch.tensor([[21603, 10]]).to(device)
    input_tokenized = torch.cat([summary_task, input_tokenized], dim=-1).to(device)
    summary_ids = t5_model.generate(input_tokenized, **_generation_kwargs())
    return t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True,
                               clean_up_tokenization_spaces=False)


# Models whose helper function returns the summary text directly.
_SIMPLE_SUMMARIZERS = {
    "Pysummarization": pysummarizer,
    "BERT": BertSummarizer,
    "SBERT Summary Tokenizer": Sbert,
    "T5 Seq2Seq": t5seq2seq,
    "BertGPT2": BertGPT2,  # experimental
}


def run_model(input_text):
    """Summarize *input_text* with the model selected in the sidebar and display it.

    Reads the module-level ``model`` selection. Sidebar entries with no
    implementation show a warning instead of silently rendering nothing.
    """
    if model == "BART":
        _show_summary(_bart_summary(input_text))
    elif model == "T5":
        _show_summary(_t5_summary(input_text))
    elif model == "Gensim":
        try:
            summary = summarize(str(input_text))
        except ValueError as err:  # e.g. raised for single-sentence input
            st.error(f"Gensim could not summarize this note: {err}")
        else:
            _show_summary(summary)
    elif model in _SIMPLE_SUMMARIZERS:
        _show_summary(_SIMPLE_SUMMARIZERS[model](input_text))
    else:
        st.warning(f"Model '{model}' is not implemented yet.")


if st.button('Submit'):
    run_model(runtext)
    st.text_area('Reference text', _join_texts(reference_text))
    # Highlight note sentences that also appear verbatim in the reference summary.
    note_sentences = [sent.text for sent in nlp(runtext).sents]
    reference_sentences = [sent.text for sent in nlp(_join_texts(reference_text)).sents]
    st.markdown(visualize(note_sentences, reference_sentences), unsafe_allow_html=True)