"""Streamlit demo for the ``Soyoung97/gec_kr`` Korean grammatical error correction model.

Launch with ``streamlit run <this file>``. The user enters a (possibly
corrupted) sentence; the app echoes it and shows the model's corrected output.
"""
from transformers import PreTrainedTokenizerFast
from tokenizers import SentencePieceBPETokenizer  # noqa: F401  (unused here; kept from original)
from transformers import BartForConditionalGeneration
import streamlit as st
import torch

MODEL_NAME = 'Soyoung97/gec_kr'

# Streamlit re-executes this whole script on every widget interaction, so the
# expensive from_pretrained() loads must be cached across reruns.
# ``st.cache_resource`` supersedes the deprecated
# ``st.cache(allow_output_mutation=True)``; fall back to the old API on
# Streamlit versions that predate it.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def get_tokenizer():
    """Load the fast tokenizer for ``MODEL_NAME`` (cached across reruns)."""
    return PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)


@_cache_resource
def get_model():
    """Load the BART model for ``MODEL_NAME`` in eval mode (cached across reruns)."""
    model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)
    model.eval()
    return model


def correct(text, tokenizer, model, max_length=256, num_beams=4,
            repetition_penalty=2.0):
    """Return the model's corrected version of ``text``.

    Args:
        text: Input sentence to correct.
        tokenizer: Tokenizer providing ``encode``/``decode`` and
            ``bos_token_id``/``eos_token_id``.
        model: Seq2seq model providing ``generate``.
        max_length: Maximum generated sequence length (tokens).
        num_beams: Beam width for beam search.
        repetition_penalty: Penalty passed to ``generate`` to discourage
            repeated tokens.

    Returns:
        The decoded output string with special tokens removed.
    """
    raw_input_ids = tokenizer.encode(text)
    # BOS/EOS are added by hand around the encoded text, as the original app did.
    input_ids = [tokenizer.bos_token_id] + raw_input_ids + [tokenizer.eos_token_id]
    with torch.no_grad():
        corrected_ids = model.generate(
            torch.tensor([input_ids]),
            max_length=max_length,
            # Use the tokenizer's EOS id (the same one appended to the input)
            # instead of a hard-coded 1, so the two cannot drift apart.
            eos_token_id=tokenizer.eos_token_id,
            num_beams=num_beams,
            early_stopping=True,
            repetition_penalty=repetition_penalty,
        )
    # Index the single batch row; ``squeeze()`` would collapse a length-1
    # output to a 0-d tensor whose ``tolist()`` is an int, not a list.
    return tokenizer.decode(corrected_ids[0].tolist(), skip_special_tokens=True)


default_text = '나느 오늘 지베 가써요'

model = get_model()
tokenizer = get_tokenizer()

st.title("GEC_KR Model Test")
text = st.text_area("Input corrupted sentence :", value=default_text)

st.markdown("## Original sentence:")
st.write(text)

# Skip generation for empty or whitespace-only input.
if text and text.strip():
    st.markdown("## Corrected output")
    with st.spinner('processing..'):
        summ = correct(text, tokenizer, model)
    st.write(summ)