Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -18,22 +18,19 @@ def get_model():
|
|
18 |
return model
|
19 |
|
20 |
|
21 |
-
default_text = '한국어는 저한테 너무 어려운 언어이었어요.'
|
22 |
-
|
23 |
model = get_model()
|
24 |
tokenizer = tokenizer()
|
25 |
st.title("Grammatical Error Correction for Korean: Demo")
|
26 |
|
27 |
-
text = st.text_input("Input corrputed sentence 1 :", value=default_text)
|
28 |
-
text1 = st.text_input("Input corrputed sentence 2 :", value=default_text_1)
|
29 |
|
30 |
-
st.markdown("## Original sentence 1:")
|
31 |
st.write(text)
|
32 |
-
st.markdown("## Original sentence 2:")
|
33 |
-
st.write(text1)
|
34 |
|
35 |
if text:
|
36 |
-
st.markdown("## Corrected output 1")
|
37 |
with st.spinner('processing..'):
|
38 |
raw_input_ids = tokenizer.encode(text)
|
39 |
input_ids = [tokenizer.bos_token_id] + \
|
@@ -42,24 +39,7 @@ if text:
|
|
42 |
max_length=256,
|
43 |
eos_token_id=1,
|
44 |
num_beams=4,
|
|
|
45 |
repetition_penalty=2.0)
|
46 |
-
|
47 |
-
|
48 |
-
output = 'Nothing generated...TT Please try again with different text!'
|
49 |
-
st.write(output)
|
50 |
-
|
51 |
-
if text1:
|
52 |
-
st.markdown("## Corrected output 2")
|
53 |
-
with st.spinner('processing..'):
|
54 |
-
raw_input_ids1 = tokenizer.encode(text)
|
55 |
-
input_ids1 = [tokenizer.bos_token_id] + \
|
56 |
-
raw_input_ids1 + [tokenizer.eos_token_id]
|
57 |
-
corrected_ids1 = model.generate(torch.tensor([input_ids1]),
|
58 |
-
max_length=256,
|
59 |
-
eos_token_id=1,
|
60 |
-
num_beams=4,
|
61 |
-
repetition_penalty=2.0)
|
62 |
-
output1 = tokenizer.decode(corrected_ids1.squeeze().tolist(), skip_special_tokens=True)
|
63 |
-
if output1 == '':
|
64 |
-
output1 = 'Nothing generated...TT Please try again with different text!'
|
65 |
-
st.write(output1)
|
|
|
18 |
return model
|
19 |
|
20 |
|
21 |
+
default_text = '한국어는 저한테 너무 어려운 언어이었어요. 저는 한국말 배워 안했어요.'
|
22 |
+
|
23 |
model = get_model()
|
24 |
tokenizer = tokenizer()
|
25 |
st.title("Grammatical Error Correction for Korean: Demo")
|
26 |
|
27 |
+
text = st.text_input("Input corrputed sentence :", value=default_text)
|
|
|
28 |
|
29 |
+
st.markdown("## Original sentence:")
|
30 |
st.write(text)
|
|
|
|
|
31 |
|
32 |
if text:
|
33 |
+
st.markdown("## Corrected output")
|
34 |
with st.spinner('processing..'):
|
35 |
raw_input_ids = tokenizer.encode(text)
|
36 |
input_ids = [tokenizer.bos_token_id] + \
|
|
|
39 |
max_length=256,
|
40 |
eos_token_id=1,
|
41 |
num_beams=4,
|
42 |
+
early_stopping=True,
|
43 |
repetition_penalty=2.0)
|
44 |
+
summ = tokenizer.decode(corrected_ids.squeeze().tolist(), skip_special_tokens=True)
|
45 |
+
st.write(summ)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|