Peter committed on
Commit 0d868fb
1 Parent(s): 766eaec

add grammar correction

Files changed (2)
  1. app.py +13 -1
  2. grammar_improve.py +46 -1
app.py CHANGED
@@ -29,6 +29,7 @@ from grammar_improve import (
     remove_repeated_words,
     remove_trailing_punctuation,
     symspeller,
+    synthesize_grammar,
 )
 from utils import corr
 
@@ -77,7 +78,7 @@ def ask_gpt(
     chat_pipe,
     speaker="person alpha",
     responder="person beta",
-    max_len=196,
+    max_len=128,
     top_p=0.95,
     top_k=50,
     temperature=0.6,
@@ -124,6 +125,7 @@ def ask_gpt(
         cln_resp = symspeller(rawtxt, sym_checker=schnellspell)
     elif not detect_propers(rawtxt):
         cln_resp = neuspell_correct(rawtxt, checker=ns_checker)
+        cln_resp = synthesize_grammar(corrector=grammarbot, message=cln_resp)
     else:
         # no correction needed
         cln_resp = rawtxt.strip()
@@ -152,6 +154,14 @@ def get_parser():
         default="ethzanalytics/ai-msgbot-gpt2-XL",  # default model
         help="the model to use for the chatbot on https://huggingface.co/models OR a path to a local model",
     )
+    parser.add_argument(
+        "--gram-model",
+        required=False,
+        type=str,
+        default="pszemraj/t5-v1_1-base-ft-jflAUG",
+        help="text2text generation model ID from huggingface for the model to correct grammar",
+    )
+
     parser.add_argument(
         "--basic-sc",
         required=False,
@@ -174,6 +184,7 @@ if __name__ == "__main__":
     default_model = str(args.model)
     model_loc = Path(default_model)  # if the model is a path, use it
     basic_sc = args.basic_sc  # whether to use the baseline spellchecker
+    gram_model = str(args.gram_model)
     device = 0 if torch.cuda.is_available() else -1
     print(f"CUDA avail is {torch.cuda.is_available()}")
 
@@ -190,6 +201,7 @@ if __name__ == "__main__":
     else:
         print("using Neuspell spell checker")
         ns_checker = load_ns_checker(fast=False)
+    grammarbot = pipeline("text2text-generation", gram_model, device=device)
 
     print(f"using model stored here: \n {model_loc} \n")
     iface = gr.Interface(
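For reference, a minimal sketch of how the pieces added to app.py fit together outside of the Gradio app. It assumes transformers and torch are installed and uses the commit's default pszemraj/t5-v1_1-base-ft-jflAUG checkpoint; the sample sentence and the standalone-script form are illustrative only, not part of the commit. Note that inside ask_gpt the grammar pass only runs on the Neuspell branch, i.e. when --basic-sc is not set and detect_propers() finds no proper nouns.

import torch
from transformers import pipeline

from grammar_improve import synthesize_grammar

# mirror the setup in app.py: GPU if available, otherwise CPU
device = 0 if torch.cuda.is_available() else -1
gram_model = "pszemraj/t5-v1_1-base-ft-jflAUG"  # default of the new --gram-model flag

# same call shape as the new line in app.py: task string, model ID, device index
grammarbot = pipeline("text2text-generation", gram_model, device=device)

raw_response = "she go to the store yesterday and buyed some milks"  # illustrative input
cleaned = synthesize_grammar(corrector=grammarbot, message=raw_response, verbose=True)
print(cleaned)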
grammar_improve.py CHANGED
@@ -14,7 +14,8 @@ import time
 import re
 import sys
 from symspellpy.symspellpy import SymSpell
-
+import transformers
+from transformers import pipeline
 from utils import suppress_stdout
 
 
@@ -108,6 +109,50 @@ def fix_punct_spacing(text: str):
 
     return cln_text
 
+def synthesize_grammar(
+    corrector: transformers.pipeline,
+    message: str,
+    num_beams=4,
+    length_penalty=0.9,
+    repetition_penalty=1.5,
+    no_repeat_ngram_size=4,
+    verbose=False,
+):
+    """
+    synthesize_grammar - correct the grammar of a message with a text2text generation pipeline and return the corrected string
+
+    Parameters
+    ----------
+    corrector : transformers.pipeline, required, the already-loaded text2text grammar-correction pipeline
+    message : str, required, the message to be corrected
+    num_beams : int, optional, by default 4, the number of beams to use for generation
+    length_penalty : float, optional, by default 0.9, the length penalty applied during generation
+    repetition_penalty : float, optional, by default 1.5, the repetition penalty applied during generation
+    no_repeat_ngram_size : int, optional, by default 4, the size of n-grams that may not repeat
+    verbose : bool, optional, by default False, whether to print the runtime of the call
+
+    Returns
+    -------
+    """
+    st = time.perf_counter()
+    input_text = clean(message, lower=False)
+    results = corrector(
+        input_text,
+        max_length=int(1.1 * len(input_text)),
+        min_length=2 if len(input_text) < 64 else int(0.2 * len(input_text)),
+        num_beams=num_beams,
+        repetition_penalty=repetition_penalty,
+        length_penalty=length_penalty,
+        no_repeat_ngram_size=no_repeat_ngram_size,
+        early_stopping=True,
+        do_sample=False,
+        clean_up_tokenization_spaces=True,
+    )
+    corrected_text = results[0]["generated_text"]
+    if verbose:
+        rt = round(time.perf_counter() - st, 2)
+        print(f"synthesizing took {rt} seconds")
+    return corrected_text.strip()
 
 """
 start of SymSpell code
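One detail worth calling out in synthesize_grammar: max_length and min_length are derived from the character length of the cleaned input, not from token counts, and are passed straight through to the pipeline. A quick sketch of that arithmetic, using placeholder strings of 50 and 200 characters:

# length bounds as computed inside synthesize_grammar (character-based, not token-based)
short_text = "x" * 50
long_text = "x" * 200

print(int(1.1 * len(short_text)))                                  # max_length = 55
print(2 if len(short_text) < 64 else int(0.2 * len(short_text)))   # min_length = 2, since 50 < 64
print(int(1.1 * len(long_text)))                                   # max_length = 220
print(2 if len(long_text) < 64 else int(0.2 * len(long_text)))     # min_length = 40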