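# gec-korean-demo / app.py
# Page residue captured with the file, kept as comments so the module parses:
#   "Soyoung97's picture" | "Update app.py" | f1cbe9d | raw | history blame
#   | No virus | 2.06 kB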
from transformers import PreTrainedTokenizerFast
from tokenizers import SentencePieceBPETokenizer
from transformers import BartForConditionalGeneration
import streamlit as st
import torch
import random
def tokenizer():
    """Load and return the fast tokenizer published as 'Soyoung97/gec_kr'."""
    return PreTrainedTokenizerFast.from_pretrained('Soyoung97/gec_kr')
@st.cache(allow_output_mutation=True)
def get_model():
    """Load the 'Soyoung97/gec_kr' BART checkpoint and return it in eval mode.

    The loader is wrapped in ``st.cache`` with ``allow_output_mutation=True``.
    """
    gec_model = BartForConditionalGeneration.from_pretrained('Soyoung97/gec_kr')
    # nn.Module.eval() returns the module itself, so it can be returned directly.
    return gec_model.eval()
# Sample Korean inputs for the demo. The "try another example" button picks one
# of these at random; the first entry is also the text box's initial value.
default_text_list = [
    '한국어는 저한테 너무 어려운 언어이었어요.',
    '저는 한국말 배워 안했어요.',
    '멍머이는 귀엽다',
    '대학원생살려!',
    '수지씨가 예쁩니까?',
    '지난날 인타넷으로 찾아냈다.',
    '그 제 꿈이 교수기 도는 것입니다',
]
default_text = default_text_list[0]

# Load the model and tokenizer before any widget is drawn.
model = get_model()
# Rebinds the name `tokenizer` from the loader function to the loaded
# tokenizer object; everything below uses it as an instance.
tokenizer = tokenizer()

st.title("Grammatical Error Correction for Korean: Demo")
text = st.text_input("Input corrupted sentence :", value=default_text)
# Each press of the button suggests one randomly chosen sample sentence.
if st.button("try another example: "):
    suggestion = random.choice(default_text_list)
    st.write(f"Try this text! : {suggestion}")
# Echo the input, then run the model on it and show the corrected sentence.
st.markdown("## Original sentence:")
st.write(text)
if text:
    st.markdown("## Corrected output")
    with st.spinner('processing..'):
        # Wrap the encoded text in BOS/EOS ids before handing it to the model.
        raw_input_ids = tokenizer.encode(text)
        input_ids = [tokenizer.bos_token_id] + \
            raw_input_ids + [tokenizer.eos_token_id]
        corrected_ids = model.generate(
            torch.tensor([input_ids]),  # batch containing this one sequence
            max_length=256,
            # Stop on the tokenizer's own EOS id so the stop token always
            # matches the one appended to the input above.
            eos_token_id=tokenizer.eos_token_id,
            num_beams=4,
            early_stopping=True,
            repetition_penalty=2.0,
        )
        # Take batch row 0 explicitly: squeeze() would also collapse a
        # one-token result into a 0-d tensor and yield a scalar, not a list.
        output = tokenizer.decode(corrected_ids[0].tolist(),
                                  skip_special_tokens=True)
    if not output:
        output = 'Nothing generated...TT Please try again with different text!'
    st.write(output)