# Hugging Face Space: title-summarization model loaders (T5, Pegasus, BART)
from transformers import AutoTokenizer, pipeline
from transformers import T5ForConditionalGeneration
from transformers import PegasusForConditionalGeneration
from transformers import BartForConditionalGeneration
import streamlit as st
# T5 | |
def get_tidy_tab_t5():
    """Return the T5 summarization pipeline, loading and caching it in
    Streamlit session state on first use."""
    state = st.session_state
    if "tidy_tab_t5" not in state:
        state.tidy_tab_t5 = load_model_t5()
    return state.tidy_tab_t5
def load_model_t5():
    """Build a summarization pipeline from the fine-tuned T5-small checkpoint."""
    checkpoint = "wgcv/tidy-tab-model-t5-small"
    return pipeline(
        "summarization",
        model=T5ForConditionalGeneration.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
def predict_model_t5(text):
    """Summarize *text* with the cached T5 pipeline.

    Returns the summary string, or None if the pipeline is unavailable
    or produced no output.
    """
    tidy_tab_t5 = get_tidy_tab_t5()
    if not tidy_tab_t5:
        return None
    # T5 checkpoints expect an explicit task prefix on the input.
    result = tidy_tab_t5("summarize: " + text, max_length=8, min_length=1)
    if result:
        return result[0]["summary_text"]
    return None
# pegasus-xsum | |
def get_tidy_tab_pegasus():
    """Return the Pegasus summarization pipeline, loading and caching it in
    Streamlit session state on first use."""
    state = st.session_state
    if "tidy_tab_pegasus" not in state:
        state.tidy_tab_pegasus = load_model_pegasus()
    return state.tidy_tab_pegasus
def load_model_pegasus():
    """Build a summarization pipeline from the fine-tuned Pegasus-XSum checkpoint."""
    checkpoint = "wgcv/tidy-tab-model-pegasus-xsum"
    return pipeline(
        "summarization",
        model=PegasusForConditionalGeneration.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
def predict_model_pegasus(text):
    """Summarize *text* with the cached Pegasus pipeline.

    Returns the summary string, or None if the pipeline is unavailable
    or produced no output.  (Unlike T5, Pegasus needs no task prefix.)
    """
    tidy_tab_pegasus = get_tidy_tab_pegasus()
    if not tidy_tab_pegasus:
        return None
    result = tidy_tab_pegasus(text, max_length=8, min_length=1)
    if result:
        return result[0]["summary_text"]
    return None
# Bart-Large | |
def get_tidy_tab_bart():
    """Return the BART summarization pipeline, loading and caching it in
    Streamlit session state on first use."""
    state = st.session_state
    if "tidy_tab_bart" not in state:
        state.tidy_tab_bart = load_model_bart()
    return state.tidy_tab_bart
def load_model_bart():
    """Build a summarization pipeline from the fine-tuned BART-large-CNN checkpoint."""
    checkpoint = "wgcv/tidy-tab-model-bart-large-cnn"
    return pipeline(
        "summarization",
        model=BartForConditionalGeneration.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
def predict_model_bart(text):
    """Summarize *text* with the cached BART pipeline.

    Returns the summary string, or None if the pipeline is unavailable
    or produced no output.

    NOTE(review): do_sample=True makes generation nondeterministic even
    with beam search — confirm this is intentional for title generation.
    """
    tidy_tab_bart = get_tidy_tab_bart()
    if not tidy_tab_bart:
        return None
    result = tidy_tab_bart(
        text, num_beams=4, max_length=12, min_length=1, do_sample=True
    )
    if result:
        return result[0]["summary_text"]
    return None