"""Summarization helpers for the tidy-tab models (T5, Pegasus-XSum, BART-Large).

Each model family exposes three functions:

* ``load_model_*``    - build a Hugging Face ``summarization`` pipeline.
* ``get_tidy_tab_*``  - return that pipeline, building it on first use and
  keeping it in ``st.session_state`` under a fixed key.
* ``predict_model_*`` - summarize one text and return the summary string,
  or ``None`` when no summary is produced.
"""

from typing import Optional

import streamlit as st
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    PegasusForConditionalGeneration,
    T5ForConditionalGeneration,
    pipeline,
)


# ---------------------------------------------------------------------------
# Shared private helpers
# ---------------------------------------------------------------------------

def _load_summarizer(model_name, model_cls):
    """Build a summarization pipeline for *model_name* using *model_cls*."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = model_cls.from_pretrained(model_name)
    return pipeline('summarization', model=model, tokenizer=tokenizer)


def _get_cached_pipeline(state_key, loader):
    """Return the object stored under *state_key* in ``st.session_state``.

    *loader* is called (with no arguments) only when the key is missing, and
    its result is stored before being returned.
    """
    if state_key not in st.session_state:
        st.session_state[state_key] = loader()
    return st.session_state[state_key]


def _is_blank(text):
    """Return True when *text* is ``None`` or a whitespace-only string."""
    return text is None or (isinstance(text, str) and not text.strip())


def _summarize(summarizer, text, **generate_kwargs):
    """Run *summarizer* on *text* and return the first summary, or ``None``.

    ``None`` is returned when the summarizer is falsy or yields no results.
    *generate_kwargs* are forwarded unchanged to the pipeline call.
    """
    if not summarizer:
        return None
    result = summarizer(text, **generate_kwargs)
    if not result:
        return None
    return result[0]['summary_text']


# ---------------------------------------------------------------------------
# T5
# ---------------------------------------------------------------------------

def get_tidy_tab_t5():
    """Return the T5 pipeline, loading it into ``st.session_state`` on first use."""
    return _get_cached_pipeline('tidy_tab_t5', load_model_t5)


def load_model_t5():
    """Build the summarization pipeline for ``wgcv/tidy-tab-model-t5-small``."""
    model_name = "wgcv/tidy-tab-model-t5-small"
    return _load_summarizer(model_name, T5ForConditionalGeneration)


def predict_model_t5(text: str) -> Optional[str]:
    """Summarize *text* with the T5 model.

    The input is prefixed with ``"summarize: "`` before being passed to the
    pipeline. Returns the summary string, or ``None`` when *text* is blank
    or the pipeline yields no result.
    """
    if _is_blank(text):
        # Nothing to summarize; avoid loading the model and feeding it only
        # the task prefix.
        return None
    return _summarize(
        get_tidy_tab_t5(),
        "summarize: " + text,
        max_length=8,
        min_length=1,
    )


# ---------------------------------------------------------------------------
# pegasus-xsum
# ---------------------------------------------------------------------------

def get_tidy_tab_pegasus():
    """Return the Pegasus pipeline, loading it into ``st.session_state`` on first use."""
    return _get_cached_pipeline('tidy_tab_pegasus', load_model_pegasus)


def load_model_pegasus():
    """Build the summarization pipeline for ``wgcv/tidy-tab-model-pegasus-xsum``."""
    model_name = "wgcv/tidy-tab-model-pegasus-xsum"
    return _load_summarizer(model_name, PegasusForConditionalGeneration)


def predict_model_pegasus(text: str) -> Optional[str]:
    """Summarize *text* with the Pegasus model.

    Returns the summary string, or ``None`` when *text* is blank or the
    pipeline yields no result.
    """
    if _is_blank(text):
        return None
    return _summarize(
        get_tidy_tab_pegasus(),
        text,
        max_length=8,
        min_length=1,
    )


# ---------------------------------------------------------------------------
# Bart-Large
# ---------------------------------------------------------------------------

def get_tidy_tab_bart():
    """Return the BART pipeline, loading it into ``st.session_state`` on first use."""
    return _get_cached_pipeline('tidy_tab_bart', load_model_bart)


def load_model_bart():
    """Build the summarization pipeline for ``wgcv/tidy-tab-model-bart-large-cnn``."""
    model_name = "wgcv/tidy-tab-model-bart-large-cnn"
    return _load_summarizer(model_name, BartForConditionalGeneration)


def predict_model_bart(text: str) -> Optional[str]:
    """Summarize *text* with the BART model.

    Returns the summary string, or ``None`` when *text* is blank or the
    pipeline yields no result.
    """
    if _is_blank(text):
        return None
    # NOTE(review): do_sample=True combined with num_beams=4 presumably makes
    # the output vary between calls for the same input - confirm this is
    # intended rather than plain beam search.
    return _summarize(
        get_tidy_tab_bart(),
        text,
        num_beams=4,
        max_length=12,
        min_length=1,
        do_sample=True,
    )