class Gramformer:

    def __init__(self, models=1, use_gpu=False):
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM
        from lm_scorer.models.auto import AutoLMScorer as LMScorer
        import errant
        import en_core_web_sm

        # spaCy pipeline that ERRANT uses to parse, align and classify edits
        nlp = en_core_web_sm.load()
        self.annotator = errant.load('en', nlp)

        if use_gpu:
            device = "cuda:0"
        else:
            device = "cpu"
        batch_size = 1
        # GPT-2 scorer used to rank candidate corrections by log-likelihood
        self.scorer = LMScorer.from_pretrained("gpt2", device=device, batch_size=batch_size)
        self.device = device
        correction_model_tag = "prithivida/grammar_error_correcter_v1"
        self.model_loaded = False

        if models == 1:
            self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag)
            self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag)
            self.correction_model = self.correction_model.to(device)
            self.model_loaded = True
            print("[Gramformer] Grammar error correct/highlight model loaded..")
        elif models == 2:
            print("TO BE IMPLEMENTED!!!")

    def correct(self, input_sentence, max_candidates=1):
        if self.model_loaded:
            # The seq2seq model was trained with a "gec: " task prefix
            correction_prefix = "gec: "
            input_sentence = correction_prefix + input_sentence
            input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
            input_ids = input_ids.to(self.device)

            # Sample up to max_candidates corrections with top-k/top-p sampling
            preds = self.correction_model.generate(
                input_ids,
                do_sample=True,
                max_length=128,
                top_k=50,
                top_p=0.95,
                early_stopping=True,
                num_return_sequences=max_candidates)

            # De-duplicate the sampled candidates
            corrected = set()
            for pred in preds:
                corrected.add(self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip())

            # Rank candidates by GPT-2 sentence log-likelihood, best first
            corrected = list(corrected)
            scores = self.scorer.sentence_score(corrected, log=True)
            ranked_corrected = [(c, s) for c, s in zip(corrected, scores)]
            ranked_corrected.sort(key=lambda x: x[1], reverse=True)
            return ranked_corrected
        else:
            print("Model is not loaded")
            return None

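    # Illustrative only: for an input like "He are moving here.", correct()
    # returns a ranked list of (candidate, log-likelihood) pairs, e.g.
    # [("He is moving here.", -20.4), ("He was moving here.", -21.7)].
    # The scores shown here are made-up placeholders.
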
    def highlight(self, orig, cor):
        edits = self._get_edits(orig, cor)
        orig_tokens = orig.split()

        ignore_indexes = []

        for edit in edits:
            edit_type = edit[0]
            edit_str_start = edit[1]
            edit_spos = edit[2]
            edit_epos = edit[3]
            edit_str_end = edit[4]

            # A multi-token edit is annotated on its first token; the
            # remaining tokens of the span are deleted after the loop.
            for i in range(edit_spos + 1, edit_epos):
                ignore_indexes.append(i)

            if edit_str_start == "":
                # Insertion: anchor the <a> tag on a neighbouring token
                if edit_spos - 1 >= 0:
                    new_edit_str = orig_tokens[edit_spos - 1]
                    edit_spos -= 1
                else:
                    new_edit_str = orig_tokens[edit_spos + 1]
                    edit_spos += 1
                if edit_type == "PUNCT":
                    st = "<a type='" + edit_type + "' edit='" + \
                        edit_str_end + "'>" + new_edit_str + "</a>"
                else:
                    st = "<a type='" + edit_type + "' edit='" + new_edit_str + \
                        " " + edit_str_end + "'>" + new_edit_str + "</a>"
                orig_tokens[edit_spos] = st
            elif edit_str_end == "":
                # Deletion: wrap the removed token in a <d> tag
                st = "<d type='" + edit_type + "' edit=''>" + edit_str_start + "</d>"
                orig_tokens[edit_spos] = st
            else:
                # Substitution: wrap the original token in a <c> tag
                st = "<c type='" + edit_type + "' edit='" + \
                    edit_str_end + "'>" + edit_str_start + "</c>"
                orig_tokens[edit_spos] = st

        for i in sorted(ignore_indexes, reverse=True):
            del orig_tokens[i]

        return " ".join(orig_tokens)

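    # Illustrative only (the exact ERRANT error type is an assumption): for
    # orig = "He are moving here." and cor = "He is moving here.",
    # highlight() returns something like:
    #   He <c type='VERB:SVA' edit='is'>are</c> moving here.
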
    def detect(self, input_sentence):
        # Grammar error detection: not yet implemented
        pass

    def _get_edits(self, orig, cor):
        # Parse both sentences with ERRANT, align them, and merge the
        # alignment into typed edit spans
        orig = self.annotator.parse(orig)
        cor = self.annotator.parse(cor)
        alignment = self.annotator.align(orig, cor)
        edits = self.annotator.merge(alignment)

        if len(edits) == 0:
            return []

        edit_annotations = []
        for e in edits:
            e = self.annotator.classify(e)
            # e.type looks like "R:VERB:SVA"; drop the leading operation
            # code ("M:", "U:" or "R:") and keep the error category along
            # with the edit strings and token offsets on both sides
            edit_annotations.append((e.type[2:], e.o_str, e.o_start, e.o_end, e.c_str, e.c_start, e.c_end))

        return edit_annotations

    def get_edits(self, orig, cor):
        # Public wrapper around _get_edits
        return self._get_edits(orig, cor)
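
# A minimal usage sketch, not part of the original file. It assumes the
# dependencies imported in __init__ (transformers, lm_scorer, errant,
# en_core_web_sm) are installed and that the
# "prithivida/grammar_error_correcter_v1" weights can be downloaded.
if __name__ == "__main__":
    gf = Gramformer(models=1, use_gpu=False)
    influent_sentence = "He are moving here."
    # correct() yields (candidate, log-likelihood) pairs, best first
    for candidate, score in gf.correct(influent_sentence, max_candidates=3):
        print(candidate, score)
        # highlight() wraps the edited tokens in <a>/<d>/<c> pseudo-HTML tags
        print(gf.highlight(influent_sentence, candidate))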