# Tidy-Tabs-Titles / model.py
from transformers import AutoTokenizer, pipeline
from transformers import T5ForConditionalGeneration
from transformers import PegasusForConditionalGeneration
from transformers import BartForConditionalGeneration
import streamlit as st
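
# Editorial sketch (an assumption, not part of the original file): each loader
# below is memoized per user session via st.session_state. Streamlit's
# @st.cache_resource decorator is a possible alternative that would share a
# single pipeline instance across all sessions instead, e.g.:
#
#   @st.cache_resource
#   def load_model_t5():
#       ...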
# T5
def get_tidy_tab_t5():
    # Cache the pipeline in the session so it is loaded once per user session.
    if 'tidy_tab_t5' not in st.session_state:
        st.session_state.tidy_tab_t5 = load_model_t5()
    return st.session_state.tidy_tab_t5

def load_model_t5():
    model_name = "wgcv/tidy-tab-model-t5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    return pipeline('summarization', model=model, tokenizer=tokenizer)

def predict_model_t5(text):
    tidy_tab_t5 = get_tidy_tab_t5()
    if tidy_tab_t5:
        # T5 expects a task prefix in front of the input text.
        text = "summarize: " + text
        result = tidy_tab_t5(text, max_length=8, min_length=1)
        if len(result) > 0:
            return result[0]['summary_text']
    return None

# pegasus-xsum
def get_tidy_tab_pegasus():
    # Cache the pipeline in the session so it is loaded once per user session.
    if 'tidy_tab_pegasus' not in st.session_state:
        st.session_state.tidy_tab_pegasus = load_model_pegasus()
    return st.session_state.tidy_tab_pegasus

def load_model_pegasus():
    model_name = "wgcv/tidy-tab-model-pegasus-xsum"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = PegasusForConditionalGeneration.from_pretrained(model_name)
    return pipeline('summarization', model=model, tokenizer=tokenizer)

def predict_model_pegasus(text):
    tidy_tab_pegasus = get_tidy_tab_pegasus()
    if tidy_tab_pegasus:
        # Pegasus needs no task prefix; the text is passed through unchanged.
        result = tidy_tab_pegasus(text, max_length=8, min_length=1)
        if len(result) > 0:
            return result[0]['summary_text']
    return None

# Bart-Large
def get_tidy_tab_bart():
    # Cache the pipeline in the session so it is loaded once per user session.
    if 'tidy_tab_bart' not in st.session_state:
        st.session_state.tidy_tab_bart = load_model_bart()
    return st.session_state.tidy_tab_bart

def load_model_bart():
    model_name = "wgcv/tidy-tab-model-bart-large-cnn"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    return pipeline('summarization', model=model, tokenizer=tokenizer)

def predict_model_bart(text):
    tidy_tab_bart = get_tidy_tab_bart()
    if tidy_tab_bart:
        # BART combines beam search with sampling and a slightly longer
        # length budget than the other two models.
        result = tidy_tab_bart(text, num_beams=4, max_length=12, min_length=1, do_sample=True)
        if len(result) > 0:
            return result[0]['summary_text']
    return None
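
# Illustrative usage sketch (an editorial addition, not part of the original
# app): running this file directly with `streamlit run model.py` renders a
# minimal demo page that feeds the same input to all three summarizers. The
# widget labels below are assumptions.
if __name__ == "__main__":
    st.title("Tidy Tabs Titles demo")
    demo_text = st.text_area("Text to condense into a short tab title")
    if st.button("Summarize") and demo_text:
        st.write("T5:", predict_model_t5(demo_text))
        st.write("Pegasus:", predict_model_pegasus(demo_text))
        st.write("BART:", predict_model_bart(demo_text))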