Soyoung97 committed on
Commit 0882f0e
•
1 Parent(s): 60848a9

Update app.py

Files changed (1)
  1. app.py +8 -28
app.py CHANGED
@@ -18,22 +18,19 @@ def get_model():
     return model
 
 
-default_text = 'ν•œκ΅­μ–΄λŠ” μ €ν•œν…Œ λ„ˆλ¬΄ μ–΄λ €μš΄ μ–Έμ–΄μ΄μ—ˆμ–΄μš”.'
-default_text_1 = 'μ œκ°€ μ •μ‹μœΌλ‘œ ν•œκ΅­λ§ λ°°μ›Œ μ•ˆν–ˆμ–΄μš”.'
+default_text = 'ν•œκ΅­μ–΄λŠ” μ €ν•œν…Œ λ„ˆλ¬΄ μ–΄λ €μš΄ μ–Έμ–΄μ΄μ—ˆμ–΄μš”. μ €λŠ” ν•œκ΅­λ§ λ°°μ›Œ μ•ˆν—€μ–΄μš”.'
+
 model = get_model()
 tokenizer = tokenizer()
 st.title("Grammatical Error Correction for Korean: Demo")
 
-text = st.text_input("Input corrputed sentence 1 :", value=default_text)
-text1 = st.text_input("Input corrputed sentence 2 :", value=default_text_1)
+text = st.text_input("Input corrputed sentence :", value=default_text)
 
-st.markdown("## Original sentence 1:")
+st.markdown("## Original sentence:")
 st.write(text)
-st.markdown("## Original sentence 2:")
-st.write(text1)
 
 if text:
-    st.markdown("## Corrected output 1")
+    st.markdown("## Corrected output")
     with st.spinner('processing..'):
         raw_input_ids = tokenizer.encode(text)
         input_ids = [tokenizer.bos_token_id] + \
@@ -42,24 +39,7 @@ if text:
                                         max_length=256,
                                         eos_token_id=1,
                                         num_beams=4,
+                                        early_stopping=True,
                                         repetition_penalty=2.0)
-        output = tokenizer.decode(corrected_ids.squeeze().tolist(), skip_special_tokens=True)
-        if output == '':
-            output = 'Nothing generated...TT Please try again with different text!'
-        st.write(output)
-
-    if text1:
-        st.markdown("## Corrected output 2")
-        with st.spinner('processing..'):
-            raw_input_ids1 = tokenizer.encode(text)
-            input_ids1 = [tokenizer.bos_token_id] + \
-                raw_input_ids1 + [tokenizer.eos_token_id]
-            corrected_ids1 = model.generate(torch.tensor([input_ids1]),
-                                            max_length=256,
-                                            eos_token_id=1,
-                                            num_beams=4,
-                                            repetition_penalty=2.0)
-            output1 = tokenizer.decode(corrected_ids1.squeeze().tolist(), skip_special_tokens=True)
-            if output1 == '':
-                output1 = 'Nothing generated...TT Please try again with different text!'
-            st.write(output1)
+        summ = tokenizer.decode(corrected_ids.squeeze().tolist(), skip_special_tokens=True)
+        st.write(summ)
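
The commit collapses the two duplicated input/generate/decode blocks into a single path and adds early_stopping=True to the beam search. Below is a minimal sketch of that resulting path outside Streamlit, assuming a generic seq2seq checkpoint loaded through transformers; the checkpoint id is a placeholder, since the diff does not show how the Space's get_model() and tokenizer() helpers load theirs.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Placeholder checkpoint id; the real Space loads its own model and tokenizer
# via get_model() and tokenizer(), which this diff does not show.
CKPT = "your-korean-gec-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(CKPT)
model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)

text = "ν•œκ΅­μ–΄λŠ” μ €ν•œν…Œ λ„ˆλ¬΄ μ–΄λ €μš΄ μ–Έμ–΄μ΄μ—ˆμ–΄μš”."

# app.py wraps the raw token ids in BOS/EOS manually before generation.
raw_input_ids = tokenizer.encode(text, add_special_tokens=False)
input_ids = [tokenizer.bos_token_id] + raw_input_ids + [tokenizer.eos_token_id]

corrected_ids = model.generate(
    torch.tensor([input_ids]),
    max_length=256,
    eos_token_id=1,         # hard-coded EOS id, as in app.py
    num_beams=4,
    early_stopping=True,    # the flag added in this commit
    repetition_penalty=2.0,
)
output = tokenizer.decode(corrected_ids.squeeze().tolist(), skip_special_tokens=True)
print(output)

With num_beams=4, early_stopping=True ends the beam search as soon as four finished hypotheses exist rather than continuing toward max_length, trading a little search breadth for lower latency in the demo.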