3v324v23 committed on
Commit
54d3b67
·
1 Parent(s): 9710d79

Initial commit

Browse files
Files changed (3) hide show
  1. InferenceServer.py +67 -1
  2. app.py +211 -0
  3. custom_req.txt +14 -0
InferenceServer.py CHANGED
@@ -1 +1,67 @@
1
- print("hello world")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uvicorn
2
+ from fastapi import File
3
+ from fastapi import FastAPI
4
+ from fastapi import UploadFile
5
+ import torch
6
+ import os
7
+ import sys
8
+ import glob
9
+ import transformers
10
+ from transformers import AutoTokenizer
11
+ from transformers import AutoModelForSeq2SeqLM
12
+ from lm_scorer.models.auto import AutoLMScorer as LMScorer
13
+
14
+
15
# Everything below runs once at import time, so all requests share the same
# scorer, tokenizer and model instances.
print("Loading models...")
app = FastAPI()

device = "cpu"
batch_size = 1
# GPT-2 based scorer used to rank the candidate corrections.
scorer = LMScorer.from_pretrained("gpt2", device=device, batch_size=batch_size)
correction_model_tag = "prithivida/grammar_error_correcter_v2"
correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag)
correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag)
# `correct` moves its input ids to `device`; keep the model on the same device
# so that changing `device` above does not cause a device mismatch in generate().
correction_model = correction_model.to(device)
24
+
25
def set_seed(seed):
    """Seed torch's RNG and, when CUDA is available, the RNG of every CUDA device."""
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
29
+
30
+ print("Models loaded !")
31
+
32
+
33
@app.get("/")
def read_root():
    """Root endpoint: return a one-element set holding the service banner."""
    banner = {"Gramformer !"}
    return banner
36
+
37
# The route is the literal path "/correct". The previous "/{correct}" declared
# a path parameter the handler never accepted, so it matched any
# single-segment URL instead of only the intended endpoint.
@app.get("/correct")
def get_correction(input_sentence: str):
    """Correct one sentence passed as the `input_sentence` query parameter.

    Re-seeds the torch RNGs with a fixed seed before generating, then returns
    a dict whose "scored_corrected_sentence" key holds the value returned by
    `correct(input_sentence)`.
    """
    set_seed(1212)
    scored_corrected_sentence = correct(input_sentence)
    return {"scored_corrected_sentence": scored_corrected_sentence}
42
+
43
def correct(input_sentence, max_candidates=1):
    """Generate, de-duplicate and score corrections for *input_sentence*.

    The sentence is prefixed with "gec: ", encoded, and passed to
    `correction_model.generate` with sampling enabled, requesting
    `max_candidates` sequences. Decoded candidates are de-duplicated, scored
    with `scorer.sentence_score(..., log=True)` and returned as a list of
    (sentence, score) tuples sorted by score, highest first.
    """
    prompt = "gec: " + input_sentence
    encoded = correction_tokenizer.encode(prompt, return_tensors='pt').to(device)

    generation_options = dict(
        do_sample=True,
        max_length=128,
        top_k=50,
        top_p=0.95,
        early_stopping=True,
        num_return_sequences=max_candidates,
    )
    outputs = correction_model.generate(encoded, **generation_options)

    # A set drops identical candidates before they are scored.
    unique_sentences = list({
        correction_tokenizer.decode(output, skip_special_tokens=True).strip()
        for output in outputs
    })
    sentence_scores = scorer.sentence_score(unique_sentences, log=True)
    return sorted(
        zip(unique_sentences, sentence_scores),
        key=lambda pair: pair[1],
        reverse=True,
    )
67
+
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from annotated_text import annotated_text
2
+ from bs4 import BeautifulSoup
3
+ from multiprocessing import Process
4
+ import streamlit as st
5
+ import pandas as pd
6
+ import torch
7
+ import math
8
+ import re
9
+ import time
10
+ import json
11
+ import os
12
+ import requests
13
+ import spacy
14
+ import errant
15
+
16
+
17
+
18
def start_server():
    """Install the packages listed in custom_req.txt, then run the uvicorn server.

    Both steps are shell commands run through `os.system`, one after the other.
    """
    startup_commands = (
        "cat custom_req.txt | xargs -n 1 -L 1 pip install -U",
        "uvicorn InferenceServer:app --port 8080 --host 0.0.0.0 --workers 1",
    )
    for command in startup_commands:
        os.system(command)
21
+
22
def load_models():
    """Make sure the model server is listening on port 8080, starting it if needed.

    When nothing is listening yet, `start_server` is launched in a daemon
    process and this function polls the port once per second until it opens.
    If the server process exits before the port opens, an error is shown and
    the Streamlit script is stopped instead of polling forever. On success the
    `models_loaded` session flag is set to True.
    """
    if is_port_in_use(8080):
        st.success("Model server already running...")
    else:
        with st.spinner(text="Loading models, please wait..."):
            proc = Process(target=start_server, args=(), daemon=True)
            proc.start()
            while not is_port_in_use(8080):
                # A dead child can never open the port; bail out rather than
                # spinning here indefinitely.
                if not proc.is_alive():
                    st.error("Model server failed to start.")
                    st.stop()
                time.sleep(1)
            st.success("Model server started.")
    st.session_state['models_loaded'] = True
33
+
34
def is_port_in_use(port):
    """Return True when a TCP connection to *port* on the loopback address succeeds.

    The probe targets 127.0.0.1 rather than 0.0.0.0: the latter is a wildcard
    meant for binding and is not accepted as a connect destination on every
    platform. A server bound to 0.0.0.0 is still reachable through loopback.
    """
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('127.0.0.1', port)) == 0
38
+
39
+ if 'models_loaded' not in st.session_state:
40
+ st.session_state['models_loaded'] = False
41
+
42
+
43
def show_highlights(input_text, corrected_sentence):
    """
    Render the input text with its edits highlighted as annotated text.

    The markup returned by `highlight` is split into tagged and plain pieces.
    Each tagged piece becomes a (text, edit type, colour) annotation and each
    plain piece is passed through unchanged. Any exception is reported with
    `st.error`, after which the script is stopped.
    """
    def strikeout(text):
        # Follow every character with U+0336 (combining long stroke overlay).
        return '\u0336'.join(text) + '\u0336'

    color_map = {'d': '#faa', 'a': '#afa', 'c': '#fea'}
    try:
        marked_up = highlight(input_text, corrected_sentence)
        annotations = []
        for piece in re.split(r'(<[dac]\s.*?<\/[dac]>)', marked_up):
            found = BeautifulSoup(piece, 'html.parser').findAll()
            if not found:
                # Plain text between edits is shown as-is.
                annotations.append(piece)
                continue

            first = found[0]
            _tag = first.name
            _type = first['type']
            _text = first['edit']
            _color = color_map[_tag]

            if _tag == 'd':
                # Deletions show the removed text struck through.
                _text = strikeout(first.text)

            annotations.append((_text, _type, _color))
        annotated_text(*annotations)
    except Exception as e:
        st.error('Some error occured!' + str(e))
        st.stop()
72
+
73
def show_edits(input_text, corrected_sentence):
    """
    Show the edits between the input and the corrected sentence as a table.

    Rows come from `get_edits` and the table is indexed by edit type. Any
    exception is reported with `st.error`, after which the script is stopped.
    """
    column_names = [
        'type',
        'original word', 'original start', 'original end',
        'correct word', 'correct start', 'correct end',
    ]
    try:
        edit_rows = get_edits(input_text, corrected_sentence)
        table = pd.DataFrame(edit_rows, columns=column_names).set_index('type')
        st.table(table)
    except Exception:
        st.error('Some error occured!')
        st.stop()
85
+
86
def highlight(orig, cor):
    """Return *orig* with each edit towards *cor* wrapped in inline markup.

    Edits come from `_get_edits(orig, cor)` and are applied to the
    whitespace-split tokens of *orig*:

    * insertion   -> ``<a type=T edit=...>neighbouring token</a>``
    * deletion    -> ``<d type=T edit=''>original text</d>``
    * replacement -> ``<c type=T edit='corrected'>original text</c>``

    Attribute values and element text are HTML-escaped, so an apostrophe or
    angle bracket in an edit (e.g. "don't") cannot break the single-quoted
    attributes. An HTML parser decodes them back to the original text.
    Unedited tokens are returned verbatim.

    NOTE(review): edit positions are used as indexes into ``orig.split()``;
    this presumably assumes the annotator tokenises the same way - confirm.
    NOTE(review): an insertion anchors on a neighbouring token and overwrites
    whatever markup an earlier edit may have written there.
    """
    import html  # local import: keeps this block independent of the file-level imports

    def _markup(tag, edit_type, edit, text):
        # Build one tagged span with every dynamic part escaped.
        return "<{tag} type='{type}' edit='{edit}'>{text}</{tag}>".format(
            tag=tag,
            type=html.escape(edit_type, quote=True),
            edit=html.escape(edit, quote=True),
            text=html.escape(text, quote=True),
        )

    edits = _get_edits(orig, cor)
    orig_tokens = orig.split()
    ignore_indexes = []

    for edit_type, edit_str_start, edit_spos, edit_epos, edit_str_end, *_ in edits:
        # A multi-token original span collapses into its first token; the
        # remaining tokens of the span are removed after all edits are applied.
        ignore_indexes.extend(range(edit_spos + 1, edit_epos))

        if edit_str_start == "":
            # Insertion: there is no original token to wrap, so attach the
            # edit to the previous token, or to the next one at position 0.
            if edit_spos - 1 >= 0:
                edit_spos -= 1
            else:
                edit_spos += 1
            anchor = orig_tokens[edit_spos]
            if edit_type == "PUNCT":
                markup = _markup("a", edit_type, edit_str_end, anchor)
            else:
                markup = _markup("a", edit_type, anchor + " " + edit_str_end, anchor)
        elif edit_str_end == "":
            markup = _markup("d", edit_type, "", edit_str_start)
        else:
            markup = _markup("c", edit_type, edit_str_end, edit_str_start)
        orig_tokens[edit_spos] = markup

    # Delete from the back so earlier indexes stay valid; de-duplicate so an
    # index is never deleted twice.
    for i in sorted(set(ignore_indexes), reverse=True):
        del orig_tokens[i]

    return " ".join(orig_tokens)
129
+
130
+
131
def _get_edits(orig, cor):
    """Return the classified edits that turn *orig* into *cor*.

    Both sentences are parsed with the module-level `annotator`, aligned and
    merged into edits. Each edit is classified and reported as a tuple
    ``(type[2:], o_str, o_start, o_end, c_str, c_start, c_end)``.
    An empty list is returned when there are no edits.
    """
    parsed_orig = annotator.parse(orig)
    parsed_cor = annotator.parse(cor)
    merged_edits = annotator.merge(annotator.align(parsed_orig, parsed_cor))

    return [
        (
            edit.type[2:],
            edit.o_str, edit.o_start, edit.o_end,
            edit.c_str, edit.c_start, edit.c_end,
        )
        for edit in map(annotator.classify, merged_edits)
    ]
149
+
150
def get_edits(orig, cor):
    """Public wrapper: return whatever `_get_edits(orig, cor)` returns."""
    edits = _get_edits(orig, cor)
    return edits
152
+
153
def get_correction(input_text):
    """Fetch a correction for *input_text* from the model server and render it.

    Sends the sentence to the local server's ``/correct`` endpoint, shows the
    first returned candidate, and renders the highlight and edit views in two
    expanders. Request, HTTP status and response-format failures are reported
    with `st.error`, after which the script is stopped.
    """
    try:
        # Passing the sentence through `params` lets requests URL-encode it;
        # concatenating it into the URL corrupted text containing '&', '#',
        # '+' or '%'.
        correct_response = requests.get(
            "http://127.0.0.1:8080/correct",
            params={"input_sentence": input_text},
            timeout=120,
        )
        correct_response.raise_for_status()
        scored_corrected_sentence = correct_response.json()["scored_corrected_sentence"]
    except (requests.RequestException, ValueError, KeyError) as e:
        st.error('Could not get a correction from the model server: ' + str(e))
        st.stop()

    if not scored_corrected_sentence:
        st.warning('The model server returned no correction.')
        return

    # Each entry is a (sentence, score) pair; only the sentence is displayed.
    corrected_sentence, _score = scored_corrected_sentence[0]
    st.markdown('##### Corrected text:')
    st.write('')
    st.success(corrected_sentence)
    with st.expander(label='Show highlights', expanded=True):
        show_highlights(input_text, corrected_sentence)
    with st.expander(label='Show edits'):
        show_edits(input_text, corrected_sentence)
169
+
170
+
171
if __name__ == "__main__":
    # Start (or detect) the model server once per session.
    if not st.session_state['models_loaded']:
        load_models()

    # Page header.
    st.title('Gramformer')
    st.subheader('A framework for correcting english grammatical errors')
    st.markdown(
        "Built for fun with 💙 by a quintessential foodie - Prithivi Da, The maker of [WhatTheFood](https://huggingface.co/spaces/prithivida/WhatTheFood), [Styleformer](https://github.com/PrithivirajDamodaran/Styleformer) and [Parrot paraphraser](https://github.com/PrithivirajDamodaran/Parrot_Paraphraser) | ✍️ [@prithivida](https://twitter.com/prithivida) |[[GitHub]](https://github.com/PrithivirajDamodaran)",
        unsafe_allow_html=True,
    )

    # Sample sentences offered in the drop-down.
    examples = [
        "what be the reason for everyone leave the comapny",
        "He are moving here.",
        "I am doing fine. How is you?",
        "How is they?",
        "Matt like fish",
        "the collection of letters was original used by the ancient Romans",
        "We enjoys horror movies",
        "Anna and Mike is going skiing",
        "I walk to the store and I bought milk",
        " We all eat the fish and then made dessert",
        "I will eat fish for dinner and drink milk",
    ]

    # `annotator` is read as a module-level name by `_get_edits`.
    nlp = spacy.load('en_core_web_sm')
    annotator = errant.load('en', nlp)

    # The chosen example pre-fills the free-text box, which the user may edit.
    chosen_example = st.selectbox(label="Choose an example", options=examples)
    st.write("(or)")
    input_text = st.text_input(label="Enter your own text", value=chosen_example)

    if input_text.strip():
        get_correction(input_text)
209
+
210
+
211
+
custom_req.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ st-annotated-text
2
+ bs4
3
+ torch
4
+ fastapi
5
+ uvicorn
6
+ spacy==2.3.0
7
+ python-Levenshtein==0.12.2
8
+ errant==2.2.0
9
+ lm-scorer==0.4.2
10
+ fsspec==2021.5.0
11
+ tokenizers
12
+ fuzzywuzzy==0.18.0
13
+ sentencepiece==0.1.95
14
+ transformers