"""Streamlit demo: grammatical error correction (GEC) for Korean.

Loads the `Soyoung97/gec_kr` BART model + tokenizer from the Hugging Face Hub,
takes a (possibly ungrammatical) Korean sentence from the user, and shows the
model's corrected output.
"""
import random

import streamlit as st
import torch
from tokenizers import SentencePieceBPETokenizer
from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast


@st.cache(allow_output_mutation=True)
def get_tokenizer():
    """Load (and cache across Streamlit reruns) the pretrained GEC tokenizer."""
    return PreTrainedTokenizerFast.from_pretrained('Soyoung97/gec_kr')


@st.cache(allow_output_mutation=True)
def get_model():
    """Load (and cache across Streamlit reruns) the pretrained GEC model.

    Returns the BART model switched to eval mode for inference.
    """
    model = BartForConditionalGeneration.from_pretrained('Soyoung97/gec_kr')
    model.eval()
    return model


default_text = '한국어는 저한테 너무 어려운 언어이었어요.'

model = get_model()
tokenizer = get_tokenizer()

st.title("Grammatical Error Correction for Korean: Demo")
# Fixed typo in the original label ("corrputed" -> "corrupted").
text = st.text_input("Input corrupted sentence :", value=default_text)

# Sample ungrammatical sentences the user can copy into the input box.
default_text_list = [
    '한국어는 저한테 너무 어려운 언어이었어요.',
    '저는 한국말 배워 안했어요.',
    '멍머이는 귀엽다',
    '대학원생살려!',
    '수지씨가 예쁩니까?',
    '지난날 인타넷으로 찾아냈다.',
    '그 제 꿈이 교수기 도는 것입니다',
]

if st.button("try another example: "):
    text_button = random.choice(default_text_list)
    st.write(f"Try this text! : {text_button}")

st.markdown("## Original sentence:")
st.write(text)

if text:
    st.markdown("## Corrected output")
    with st.spinner('processing..'):
        # Wrap the sentence in BOS/EOS ids, as the model expects.
        raw_input_ids = tokenizer.encode(text)
        input_ids = [tokenizer.bos_token_id] + raw_input_ids + [tokenizer.eos_token_id]
        # Inference only: disable gradient tracking to save memory/compute.
        with torch.no_grad():
            corrected_ids = model.generate(
                torch.tensor([input_ids]),
                max_length=256,
                # NOTE(review): hard-coded eos id; presumably equals
                # tokenizer.eos_token_id for this checkpoint — confirm.
                eos_token_id=1,
                num_beams=4,
                early_stopping=True,
                repetition_penalty=2.0,
            )
        output = tokenizer.decode(corrected_ids.squeeze().tolist(), skip_special_tokens=True)
        if output == '':
            output = 'Nothing generated...TT Please try again with different text!'
        st.write(output)