JunchuanYu commited on
Commit
dee3f71
1 Parent(s): d6353f2

Upload gramformer.py

Browse files
Files changed (1) hide show
  1. gramformer.py +128 -0
gramformer.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class Gramformer:
    """Grammar error correction and highlighting.

    Wraps a seq2seq grammar-correction model (Hugging Face `transformers`)
    for generating corrected sentences, and the ERRANT toolkit for aligning
    an original/corrected sentence pair into classified edits that can be
    rendered as inline pseudo-HTML highlight tags.
    """

    def __init__(self, models=1, use_gpu=False):
        """Load the correction model and the ERRANT annotator.

        Args:
            models: 1 loads the grammar-error-correction model;
                2 is reserved for a future detector variant (not implemented).
            use_gpu: place the model on "cuda:0" when True, else CPU.
        """
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM
        import errant

        # ERRANT aligns an (original, corrected) sentence pair and
        # classifies the differences; used by highlight()/get_edits().
        self.annotator = errant.load('en')

        device = "cuda:0" if use_gpu else "cpu"
        self.device = device
        correction_model_tag = "prithivida/grammar_error_correcter_v1"
        self.model_loaded = False

        if models == 1:
            self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag, use_auth_token=False)
            self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag, use_auth_token=False)
            self.correction_model = self.correction_model.to(device)
            self.model_loaded = True
            print("[Gramformer] Grammar error correct/highlight model loaded..")
        elif models == 2:
            # TODO: detector model variant not yet implemented.
            print("TO BE IMPLEMENTED!!!")

    def correct(self, input_sentence, max_candidates=1):
        """Generate corrected candidates for *input_sentence*.

        Args:
            input_sentence: raw sentence to correct.
            max_candidates: number of beam candidates to return.

        Returns:
            A set of corrected sentence strings (duplicates collapse),
            or None when the model was not loaded.
        """
        if not self.model_loaded:
            print("Model is not loaded")
            return None

        # The model was trained with a task prefix on its inputs.
        correction_prefix = "gec: "
        input_sentence = correction_prefix + input_sentence
        input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
        input_ids = input_ids.to(self.device)

        # NOTE(review): do_sample=True combined with beam search makes the
        # output non-deterministic across calls — confirm this is intended.
        preds = self.correction_model.generate(
            input_ids,
            do_sample=True,
            max_length=128,
            num_beams=7,
            early_stopping=True,
            num_return_sequences=max_candidates)

        corrected = set()
        for pred in preds:
            corrected.add(
                self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip())
        return corrected

    def highlight(self, orig, cor):
        """Return *orig* with each edit wrapped in a pseudo-HTML tag.

        Insertions become <a ...>, deletions <d ...>, replacements <c ...>;
        the tag's `type` attribute is the ERRANT error category and `edit`
        holds the suggested correction text.

        Args:
            orig: original (uncorrected) sentence.
            cor: corrected sentence to diff against.

        Returns:
            The original sentence as a single space-joined string with
            edit spans tagged.
        """
        edits = self._get_edits(orig, cor)
        orig_tokens = orig.split()

        # Token indexes to drop after tagging: a multi-token edit span keeps
        # only its first token. A set (not a list) so overlapping edit spans
        # cannot register the same index twice — duplicate indexes in the
        # reverse-sorted delete loop below would remove the wrong tokens.
        ignore_indexes = set()

        for edit in edits:
            edit_type = edit[0]
            edit_str_start = edit[1]   # original-side text of the edit
            edit_spos = edit[2]        # start token index in orig
            edit_epos = edit[3]        # end token index (exclusive) in orig
            edit_str_end = edit[4]     # corrected-side text

            # Mark every token after the first in the span for deletion.
            for i in range(edit_spos + 1, edit_epos):
                ignore_indexes.add(i)

            if edit_str_start == "":
                # Insertion: anchor the tag on a neighbouring token,
                # preferring the one to the left.
                if edit_spos - 1 >= 0:
                    new_edit_str = orig_tokens[edit_spos - 1]
                    edit_spos -= 1
                elif edit_spos + 1 < len(orig_tokens):
                    new_edit_str = orig_tokens[edit_spos + 1]
                    edit_spos += 1
                else:
                    # No neighbour to anchor to (empty/one-token sentence);
                    # previously this raised IndexError.
                    continue
                if edit_type == "PUNCT":
                    st = "<a type='" + edit_type + "' edit='" + \
                        edit_str_end + "'>" + new_edit_str + "</a>"
                else:
                    st = "<a type='" + edit_type + "' edit='" + new_edit_str + \
                        " " + edit_str_end + "'>" + new_edit_str + "</a>"
                orig_tokens[edit_spos] = st
            elif edit_str_end == "":
                # Deletion: tag the removed token with an empty edit.
                st = "<d type='" + edit_type + "' edit=''>" + edit_str_start + "</d>"
                orig_tokens[edit_spos] = st
            else:
                # Replacement: tag the original token with its correction.
                st = "<c type='" + edit_type + "' edit='" + \
                    edit_str_end + "'>" + edit_str_start + "</c>"
                orig_tokens[edit_spos] = st

        # Delete from the end so earlier indexes remain valid.
        for i in sorted(ignore_indexes, reverse=True):
            del orig_tokens[i]

        return " ".join(orig_tokens)

    def detect(self, input_sentence):
        # TO BE IMPLEMENTED
        pass

    def _get_edits(self, orig, cor):
        """Align *orig* and *cor* with ERRANT and return classified edits.

        Returns:
            A list of 7-tuples
            (type, o_str, o_start, o_end, c_str, c_start, c_end), where
            `type` has ERRANT's leading operation prefix (e.g. "R:")
            stripped. Empty list when the sentences align with no edits.
        """
        orig = self.annotator.parse(orig)
        cor = self.annotator.parse(cor)
        alignment = self.annotator.align(orig, cor)
        edits = self.annotator.merge(alignment)

        edit_annotations = []
        for e in edits:
            e = self.annotator.classify(e)
            edit_annotations.append(
                (e.type[2:], e.o_str, e.o_start, e.o_end, e.c_str, e.c_start, e.c_end))
        # An empty list already covers both "no edits" cases; the original
        # double length check was redundant.
        return edit_annotations

    def get_edits(self, orig, cor):
        """Public wrapper around :meth:`_get_edits`."""
        return self._get_edits(orig, cor)