Spaces:
Runtime error
Runtime error
""" | |
grammar_improve.py - functions to improve the grammar of a user's input or the model's output.
""" | |
import logging | |
logging.basicConfig(level=logging.INFO) | |
import math | |
import pprint as pp | |
import re | |
import time | |
import neuspell | |
import transformers | |
from cleantext import clean | |
from neuspell import BertChecker, SclstmChecker | |
from symspellpy.symspellpy import SymSpell | |
from utils import suppress_stdout | |
def detect_propers(text: str):
    """
    detect_propers - heuristically detect whether a string contains proper nouns

    A word counts as a likely proper noun when it starts with an uppercase letter
    but does not open a sentence. "Opening a sentence" means it is the first word
    of the text, or it is preceded by ".", "!" or "?". The pronoun "I" and its
    contractions (e.g. "I'm") are ignored. This is a capitalization heuristic,
    not a part-of-speech tagger.

    Args:
        text (str): [string to be checked]

    Returns:
        [bool]: [True if the string contains at least one likely proper noun]
    """
    if not text:
        return False
    # word tokens, allowing contractions (don't, o’clock) and hyphenation (Anne-Marie)
    word_pat = re.compile(r"(?:\w+['’])?\w+(?:-(?:\w+['’])?\w+)*")
    sentence_end = re.compile(r"[.!?]")
    prev_end = 0
    for idx, match in enumerate(word_pat.finditer(text)):
        gap = text[prev_end : match.start()]
        prev_end = match.end()
        word = match.group(0)
        # capitalization at the start of a sentence says nothing about the word itself
        if idx == 0 or sentence_end.search(gap):
            continue
        if not word[0].isupper():
            continue
        # "I", "I'm", "I’ve", ... are always capitalized but are not proper nouns
        if re.split(r"['’]", word, maxsplit=1)[0] == "I":
            continue
        return True
    return False
def fix_punct_spaces(string):
    """
    fix_punct_spaces - normalize spacing around punctuation.

    Whitespace before a run of punctuation is removed, spaces inside the run are
    removed, and exactly one space follows it.
    For example, "hello , there" -> "hello, there".
    A bare "." or "," between two digits (3.14, 1,000) is left untouched, so
    numbers are not split.

    Parameters
    ----------
    string : str, required, input string to be corrected

    Returns
    -------
    str, corrected string with leading/trailing whitespace stripped
    """
    fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")

    def _normalize(match):
        punct = match.group(1)
        src = match.string
        start, end = match.span()
        # a lone "." or "," with digits on both sides is a decimal point or a
        # thousands separator, not sentence punctuation
        if (
            match.group(0) in (".", ",")
            and start > 0
            and end < len(src)
            and src[start - 1].isdigit()
            and src[end].isdigit()
        ):
            return punct
        return "{} ".format(punct.replace(" ", ""))

    return fix_spaces.sub(_normalize, string).strip()
def split_sentences(text: str): | |
""" | |
split_sentences - split a string into a list of sentences that keep their ending punctuation. powered by regex witchcraft | |
Args: | |
text (str): [string to be split] | |
Returns: | |
[list]: [list of strings] | |
""" | |
return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text) | |
def remove_repeated_words(bot_response):
    """
    remove_repeated_words - remove repeated words from a string, keeping only the first instance of each word

    Words are whitespace-separated tokens compared exactly (case- and
    punctuation-sensitive). The kept words are re-joined with single spaces.

    Parameters
    ----------
    bot_response : str
        string to remove repeated words from

    Returns
    -------
    str
        string containing the first instance of each word, in original order
    """
    # dicts preserve insertion order, so this keeps each word's first occurrence
    # with O(1) membership checks instead of scanning a growing list (O(n^2))
    return " ".join(dict.fromkeys(bot_response.split()))
def remove_trailing_punctuation(text: str, fuLL_strip=False):
    """
    remove_trailing_punctuation - strip punctuation from the ends of a string so replies read more naturally to end users.

    Note: this uses str.strip, so matching characters are removed from BOTH ends
    of the string, not only from the end.

    Args:
        text (str): [string to be cleaned]
        fuLL_strip (bool, optional): [when True, "?" and "!" are stripped too. Defaults to False.]

    Returns:
        [str]: [cleaned string]
    """
    unwanted = ".,;:"
    if fuLL_strip:
        unwanted = "?!" + unwanted
    return text.strip(unwanted)
def fix_punct_spacing(text: str):
    """
    fix_punct_spacing - fix spacing around punctuation and collapse repeated symbols.

    Whitespace before a run of punctuation is removed and exactly one space
    follows it. Consecutive identical non-word characters are then collapsed
    ("!!" -> "!", double spaces -> one space). A bare "." or "," between two
    digits (3.50, 1,000) is left untouched so numbers are not split.

    Args:
        text (str): string to be cleaned

    Returns:
        str: cleaned string with leading/trailing whitespace stripped
    """
    fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")

    def _normalize(match):
        punct = match.group(1)
        src = match.string
        start, end = match.span()
        # decimal point / thousands separator: keep the number intact
        if (
            match.group(0) in (".", ",")
            and start > 0
            and end < len(src)
            and src[start - 1].isdigit()
            and src[end].isdigit()
        ):
            return punct
        return "{} ".format(punct.replace(" ", ""))

    spc_text = fix_spaces.sub(_normalize, text)
    # drop every non-word character that is immediately followed by the same character
    cln_text = re.sub(r"(\W)(?=\1)", "", spc_text)
    # the substitution above appends a space after punctuation, including at the very end
    return cln_text.strip()
def synthesize_grammar(
    corrector: transformers.pipeline,
    message: str,
    num_beams=4,
    length_penalty=0.9,
    repetition_penalty=1.5,
    no_repeat_ngram_size=4,
    verbose=False,
):
    """
    synthesize_grammar - use a SyntaxSynthesizer model to generate a string from a message
    Parameters
    ----------
    corrector : transformers.pipeline, required, which is the SyntaxSynthesizer model already loaded
    message : str, required, which is the message to be corrected
    num_beams : int, optional, by default 4, which is the number of beams to use for the model
    length_penalty : float, optional, by default 0.9, which is the length penalty to use for the model
    repetition_penalty : float, optional, by default 1.5, which is the repetition penalty to use for the model
    no_repeat_ngram_size : int, optional, by default 4, which is the n-gram size to use for the model
    verbose : bool, optional, by default False, which is whether to print the runtime of the model
    Returns
    -------
    str, the "generated_text" of the first pipeline result, with surrounding whitespace stripped
    """
    st = time.perf_counter()
    input_text = clean(message, lower=False)
    # generation lengths are derived from the tokenized input length
    input_len = len(corrector.tokenizer(input_text).input_ids)
    # short inputs only need a 2-token floor; long ones must keep at least 20% of their length
    min_length = 2 if input_len < 64 else int(0.2 * input_len)
    # allow ~10% growth. ceil (not int) so short inputs still get at least one extra
    # token (int(1.1 * n) == n for n < 10), and never let max_length fall below min_length
    max_length = max(int(math.ceil(1.1 * input_len)), min_length)
    results = corrector(
        input_text,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=True,
        do_sample=False,
        clean_up_tokenization_spaces=True,
    )
    corrected_text = results[0]["generated_text"]
    if verbose:
        rt = round(time.perf_counter() - st, 2)
        print(f"synthesizing took {rt} seconds")
    return corrected_text.strip()
""" | |
start of SymSpell code | |
""" | |
def symspeller(
    my_string: str,
    sym_checker=None,
    max_dist: int = 2,
    prefix_length: int = 7,
    ignore_non_words=True,
    dictionary_path: str = None,
    bigram_path: str = None,
    verbose=False,
):
    """
    symspeller - a wrapper around SymSpell.lookup_compound from symspellpy
    Parameters
    ----------
    my_string : str, required, the string to be checked (must be non-empty)
    sym_checker : SymSpell, optional, default=None, the SymSpell object to use; if None, a new one is built with build_symspell_obj()
    max_dist : int, optional, default=2, the maximum edit distance per single word (not per whole string) when looking up replacements
    prefix_length : int, optional, default=7, the prefix length; only used when a new SymSpell object is built
    ignore_non_words : bool, optional, default=True, passed through to lookup_compound
    dictionary_path : str, optional, default=None, path to the dictionary file; only used when a new SymSpell object is built
    bigram_path : str, optional, default=None, path to the bigram dictionary file; only used when a new SymSpell object is built
    verbose : bool, optional, default=False, whether to print progress and the suggestions found
    Returns
    -------
    str, the term of the top suggestion, or clean(my_string) when there are no suggestions
    Raises
    ------
    AssertionError, if my_string is empty
    """
    assert len(my_string) > 0, "entered string for correction is empty"
    if sym_checker is None:
        # need to create a new class object. user can specify their own dictionary and bigram files
        if verbose:
            print("creating new SymSpell object")
        sym_checker = build_symspell_obj(
            edit_dist=max_dist,
            prefix_length=prefix_length,
            dictionary_path=dictionary_path,
            bigram_path=bigram_path,
        )
    elif verbose:
        print("using existing SymSpell object")
    # max edit distance per lookup (per single word, not per whole input string)
    suggestions = sym_checker.lookup_compound(
        my_string,
        max_edit_distance=max_dist,
        ignore_non_words=ignore_non_words,
        ignore_term_with_digits=True,
        transfer_casing=True,
    )
    if verbose:
        print(f"{len(suggestions)} suggestions found")
        print(f"the original string is:\n\t{my_string}")
        sug_list = [sug.term for sug in suggestions]
        print(f"suggestions:\n\t{sug_list}\n")
    if not suggestions:
        # no correction because no suggestions.
        # NOTE(review): clean() runs with its default options here, while synthesize_grammar
        # passes lower=False; confirm the casing of this fallback is intended
        return clean(my_string)
    # the first result is the most likely; use the public .term (as above) instead of the private ._term
    return suggestions[0].term
def build_symspell_obj(
    edit_dist=2,
    prefix_length=7,
    dictionary_path=None,
    bigram_path=None,
):
    """
    build_symspell_obj - build a SymSpell object and load its unigram and bigram dictionaries
    Args:
        edit_dist (int, optional): lookup edit distance; the object is created with
            max_dictionary_edit_distance = edit_dist + 2. Defaults to 2.
        prefix_length (int, optional): prefix length passed to SymSpell. Defaults to 7.
        dictionary_path (str, optional): unigram frequency file (term in column 0, count in column 1).
            Defaults to "symspell_rsc/frequency_dictionary_en_82_765.txt", resolved relative to the
            current working directory.
        bigram_path (str, optional): bigram frequency file (terms in columns 0-1, count in column 2).
            Defaults to "symspell_rsc/frequency_bigramdictionary_en_243_342.txt", resolved relative to
            the current working directory.
    Returns:
        SymSpell: a SymSpell object. If a dictionary file could not be loaded, a warning is logged
        and the object is returned without that dictionary.
    """
    if dictionary_path is None:
        dictionary_path = r"symspell_rsc/frequency_dictionary_en_82_765.txt"
    if bigram_path is None:
        bigram_path = r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"
    sym_checker = SymSpell(
        max_dictionary_edit_distance=edit_dist + 2, prefix_length=prefix_length
    )
    # term_index is the column of the term and count_index is the column of the term frequency.
    # The loaders report failure (e.g. a missing file) through a False return value rather
    # than an exception, so check it instead of silently returning an empty checker.
    if not sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1):
        logging.warning("could not load SymSpell dictionary from %s", dictionary_path)
    if not sym_checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2):
        logging.warning("could not load SymSpell bigram dictionary from %s", bigram_path)
    return sym_checker
""" | |
# if using t5b_correction to check for spelling errors, use this code to initialize the objects | |
import torch | |
from transformers import T5Tokenizer, T5ForConditionalGeneration | |
model_name = 'deep-learning-analytics/GrammarCorrector' | |
# torch_device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
torch_device = 'cpu' | |
gc_tokenizer = T5Tokenizer.from_pretrained(model_name) | |
gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_device) | |
""" | |
def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
    """
    t5b_correction - correct a string using a text2text-generation pipeline model from transformers
    Parameters
    ----------
    prompt : str, required, input prompt to be corrected; it is sent to the pipeline prefixed with "grammar: "
    korrektor : transformers.pipeline, required, pipeline object
    verbose : bool, optional, whether to print the length settings and the raw result. Defaults to False.
    beams : int, optional, number of beams to use for the correction. Defaults to 4.
    Returns
    -------
    The raw return value of korrektor(...), NOT a plain str. For a transformers text2text
    pipeline this is a list of dicts (e.g. [{"generated_text": ...}]); index into it to
    get the corrected text.
    """
    # NOTE: both bounds are computed from the prompt's CHARACTER count, while the pipeline's
    # max_length is measured in tokens, so the cap is loose. p_min_len is only reported in the
    # verbose message; it is not passed to the pipeline.
    p_min_len = int(math.ceil(0.9 * len(prompt)))
    p_max_len = int(math.ceil(1.1 * len(prompt)))
    if verbose:
        print(f"setting min to {p_min_len} and max to {p_max_len}\n")
    gcorr_result = korrektor(
        f"grammar: {prompt}",
        return_text=True,
        clean_up_tokenization_spaces=True,
        num_beams=beams,
        max_length=p_max_len,
        repetition_penalty=1.3,
        length_penalty=0.2,
        no_repeat_ngram_size=2,
    )
    if verbose:
        print(f"grammar correction result: \n\t{gcorr_result}\n")
    return gcorr_result
def all_neuspell_chkrs():
    """
    all_neuspell_chkrs - print and return every attribute name exposed by the neuspell module
    Parameters
    ----------
    None
    Returns
    -------
    list of str, the result of dir(neuspell). This includes the checker classes
    (e.g. BertChecker, SclstmChecker) but also every other attribute of the module.
    """
    names = dir(neuspell)
    print("\navailable checkers:")
    pp.pprint(names, indent=4, compact=True)
    return names
def load_ns_checker(customckr=None, fast=False):
    """
    load_ns_checker - helper function, load / "set up" a pretrained neuspell checker
    Args:
        customckr (type, optional): a neuspell checker CLASS (not an instance); it is
            instantiated as customckr(pretrained=True). When None, a default is picked via `fast`.
        fast (bool, optional): only consulted when customckr is None. False -> BertChecker
            (the default, best balance); True -> SclstmChecker (faster but less accurate).
    Returns:
        the instantiated neuspell checker
    """
    start = time.perf_counter()
    if customckr is not None:
        checker_cls = customckr
    elif fast:
        checker_cls = SclstmChecker
    else:
        checker_cls = BertChecker
    # keep the console quiet while the pretrained checker is constructed
    with suppress_stdout():
        checker = checker_cls(pretrained=True)
    elapsed_min = (time.perf_counter() - start) / 60
    print(f"\n\nloaded checker in {elapsed_min} minutes")
    return checker
def neuspell_correct(input_text: str, checker=None, verbose=False):
    """
    neuspell_correct - spell-correct a string with neuspell, then normalize punctuation spacing.
    Note that modifications to the checker are needed if doing list-based corrections.
    Parameters
    ----------
    input_text : str, required, input string to be corrected; strings shorter than 4 characters are returned unchanged
    checker : neuspell checker, optional. Defaults to None, in which case a new SclstmChecker(pretrained=True) is created on every call
    verbose : bool, optional, whether to print the corrected string. Defaults to False.
    Returns
    -------
    str, corrected string
    """
    too_short = isinstance(input_text, str) and len(input_text) < 4
    if too_short:
        print(f"input text of {input_text} is too short to be corrected")
        return input_text
    if checker is None:
        print("NOTE - no checker provided, loading default checker")
        checker = SclstmChecker(pretrained=True)
    cleaned_txt = fix_punct_spaces(checker.correct(input_text))
    if verbose:
        print(f"neuspell correction result: \n\t{cleaned_txt}\n")
    return cleaned_txt
def grammarpipe(corrector, qphrase: str):
    """
    grammarpipe - correct a string using a text2text-generation pipeline from transformers.
    NOTE: this is the original function used in the project and is slated to be replaced.
    Args:
        corrector (transformers.pipeline): [pipeline object, already created with the relevant model]
        qphrase (str): [text to be corrected; strings shorter than 4 characters are returned unchanged]
    Returns:
        [str]: [the corrected text, or clean(qphrase) if the pipeline call raised an exception]
    """
    if isinstance(qphrase, str) and len(qphrase) < 4:
        print(f"input text of {qphrase} is too short to be corrected")
        return qphrase
    try:
        outputs = corrector(
            clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
        )
        best = outputs[0]
        return best["generated_text"]
    except Exception as err:
        # best-effort: report the failure and fall back to the cleaned input
        print(f"NOTE - failed to correct with grammarpipe:\n {err}")
        return clean(qphrase)
def DLA_correct(qphrase: str, tokenizer=None, model=None, **gen_kwargs):
    """
    DLA_correct - an "overhead" function that runs correct_grammar() on each sentence of a string individually
    Args:
        qphrase (str): [string to be corrected; strings shorter than 4 characters are returned unchanged]
        tokenizer: tokenizer forwarded to correct_grammar(); required for longer input
        model: model forwarded to correct_grammar(); required for longer input
        **gen_kwargs: extra keyword arguments forwarded to correct_grammar() (e.g. beams, device)
    Returns:
        str, the corrected sentences joined with " "
    Raises:
        TypeError: if tokenizer or model is missing for input long enough to be corrected
    """
    if isinstance(qphrase, str) and len(qphrase) < 4:
        print(f"input text of {qphrase} is too short to be corrected")
        return qphrase
    # correct_grammar() takes tokenizer and model as required positional arguments
    if tokenizer is None or model is None:
        raise TypeError(
            "DLA_correct requires both `tokenizer` and `model` to call correct_grammar()"
        )
    sentences = split_sentences(qphrase)
    if len(sentences) == 1:
        return correct_grammar(sentences[0], tokenizer, model, **gen_kwargs)
    full_cor = [
        correct_grammar(clean(sen), tokenizer, model, **gen_kwargs)
        for sen in sentences
        if sen.strip()  # skip empty fragments, e.g. produced by trailing whitespace
    ]
    return " ".join(full_cor)
def correct_grammar(
    input_text: str,
    tokenizer,
    model,
    n_results: int = 1,
    beams: int = 8,
    temp=1,
    no_repeat_ngram_size=4,
    rep_penalty=2.5,
    device="cpu",
):
    """
    correct_grammar - correct a string with a seq2seq model (e.g. T5) and its tokenizer.
    This function is an alternative to the t5b_correction function.
    Parameters
    ----------
    input_text : str, required, input string to be corrected
    tokenizer : transformers tokenizer (e.g. T5Tokenizer), required, already created for the model
    model : transformers seq2seq model (e.g. T5ForConditionalGeneration), required; expected to be on `device`
    n_results : int, optional, number of sequences to generate and return. Defaults to 1.
    beams : int, optional, number of beams to use for the correction. Defaults to 8.
    temp : int, optional, temperature passed to model.generate(). Defaults to 1.
    no_repeat_ngram_size : int, optional, size of n-grams that may not repeat in the output. Defaults to 4.
    rep_penalty : float, optional, repetition penalty passed to model.generate(). Defaults to 2.5.
    device : str, optional, device the tokenized batch is moved to. Defaults to 'cpu'.
    Returns
    -------
    str, corrected string (or list of strings if n_results > 1).
    Inputs that tokenize to fewer than 4 tokens are returned unchanged.
    """
    st = time.perf_counter()
    if len(tokenizer(input_text).input_ids) < 4:
        logging.info(f"input text of {input_text} is too short to be corrected")
        return input_text
    # NOTE: derived from the CHARACTER count (a loose upper bound on the token count), capped at 128
    max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    ).to(device)
    translated = model.generate(
        **batch,
        max_length=max_length,
        min_length=min(10, len(input_text)),
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=rep_penalty,
        num_beams=beams,
        num_return_sequences=n_results,
        temperature=temp,
    )
    # skip_special_tokens drops markers such as <pad> and </s> from the decoded text
    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    rt_min = (time.perf_counter() - st) / 60
    print(f"\n\ncorrected in {rt_min} minutes")
    if not isinstance(tgt_text, list):
        return tgt_text
    # honor the documented contract: a list when several results were requested
    if n_results > 1:
        return tgt_text
    return tgt_text[0]