|
import streamlit as st
|
|
import torch
|
|
import spacy
|
|
|
|
|
|
|
|
from Dataset import SkimlitDataset
|
|
from Embeddings import get_embeddings
|
|
from Model import SkimlitModel
|
|
from Tokenizer import Tokenizer
|
|
from LabelEncoder import LabelEncoder
|
|
from MakePredictions import make_skimlit_predictions
|
|
from RandomAbstract import Choose_Random_text
|
|
|
|
MODEL_PATH = 'skimlit-model-final-1.pt'
|
|
TOKENIZER_PATH = 'tokenizer.json'
|
|
LABEL_ENOCDER_PATH = "label_encoder.json"
|
|
EMBEDDING_FILE_PATH = 'glove.6B.300d.txt'
|
|
|
|
@st.cache()
|
|
def create_utils(model_path, tokenizer_path, label_encoder_path, embedding_file_path):
|
|
tokenizer = Tokenizer.load(fp=tokenizer_path)
|
|
label_encoder = LabelEncoder.load(fp=label_encoder_path)
|
|
embedding_matrix = get_embeddings(embedding_file_path, tokenizer, 300)
|
|
model = SkimlitModel(embedding_dim=300, vocab_size=len(tokenizer), hidden_dim=128, n_layers=3, linear_output=128, num_classes=len(label_encoder), pretrained_embeddings=embedding_matrix)
|
|
model.load_state_dict(torch.load(model_path, map_location='cpu'))
|
|
print(model)
|
|
return model, tokenizer, label_encoder
|
|
|
|
def model_prediction(abstract, model, tokenizer, label_encoder):
|
|
objective = ''
|
|
background = ''
|
|
method = ''
|
|
conclusion = ''
|
|
result = ''
|
|
|
|
lines, pred = make_skimlit_predictions(abstract, model, tokenizer, label_encoder)
|
|
|
|
|
|
for i, line in enumerate(lines):
|
|
if pred[i] == 'OBJECTIVE':
|
|
objective = objective + line
|
|
|
|
elif pred[i] == 'BACKGROUND':
|
|
background = background + line
|
|
|
|
elif pred[i] == 'METHODS':
|
|
method = method + line
|
|
|
|
elif pred[i] == 'RESULTS':
|
|
result = result + line
|
|
|
|
elif pred[i] == 'CONCLUSIONS':
|
|
conclusion = conclusion + line
|
|
|
|
return objective, background, method, conclusion, result
|
|
|
|
|
|
|
|
def main():
|
|
|
|
st.set_page_config(
|
|
page_title="SkimLit",
|
|
page_icon="📄",
|
|
layout="wide",
|
|
initial_sidebar_state="expanded"
|
|
)
|
|
|
|
st.title('SkimLit📄🔥')
|
|
st.caption('An NLP model to classify abstract sentences into the role they play (e.g. objective, methods, results, etc..) to enable researchers to skim through the literature and dive deeper when necessary.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
col1, col2 = st.columns(2)
|
|
|
|
with col1:
|
|
st.write('#### Entre Abstract Here !!')
|
|
abstract = st.text_area(label='', height=200)
|
|
|
|
agree = st.checkbox('Show Example Abstract')
|
|
predict = st.button('Extract !')
|
|
|
|
if agree:
|
|
example_input = Choose_Random_text()
|
|
st.info(example_input)
|
|
|
|
|
|
if predict:
|
|
with col2:
|
|
with st.spinner('Wait for prediction....'):
|
|
skimlit_model, tokenizer, label_encoder = create_utils(MODEL_PATH, TOKENIZER_PATH, LABEL_ENOCDER_PATH, EMBEDDING_FILE_PATH)
|
|
objective, background, methods, conclusion, result = model_prediction(abstract, skimlit_model, tokenizer, label_encoder)
|
|
|
|
st.markdown(f'### Objective : ')
|
|
st.info(objective)
|
|
|
|
st.markdown(f'### Background : ')
|
|
st.info(background)
|
|
|
|
st.markdown(f'### Methods : ')
|
|
st.info(methods)
|
|
|
|
st.markdown(f'### Result : ')
|
|
st.info(result)
|
|
|
|
st.markdown(f'### Conclusion : ')
|
|
st.info(conclusion)
|
|
|
|
|
|
|
|
|
|
if __name__=='__main__':
|
|
main() |