Olivia Figueira committed • 37d028c
1 Parent(s): e3d0258

Added other LMs to demo

Files changed:
- critic/critic.py (+112 -20)
- requirements.txt (+6 -3)
critic/critic.py
CHANGED
@@ -5,24 +5,19 @@ import hashlib
 import numpy as np
 from tqdm import tqdm
 from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
+from transformers import OPTForCausalLM, GPTNeoForCausalLM
+from transformers import RobertaTokenizer, RobertaForCausalLM, RobertaConfig
+from transformers import XLMRobertaTokenizer, XLMRobertaForCausalLM, XLMRobertaConfig
+from transformers import BartTokenizer, BartForCausalLM
 import nltk
+import pandas as pd
 nltk.download('punkt')
 
 sys.path.insert(0, '.')
 from critic.perturbations import get_local_neighbors_char_level, get_local_neighbors_word_level
 from utils.spacy_tokenizer import spacy_tokenize_gec
 
-
-tokenizer = GPT2Tokenizer.from_pretrained(model_name)
-tokenizer.pad_token = tokenizer.eos_token
-model = GPT2LMHeadModel.from_pretrained(model_name)
-model.eval()
-#model.cuda()
-model.cpu()
-print (f'Loaded {model_name}')
-
-
-def get_gpt2_loss(input_ids, attention_mask, labels):
+def get_gpt2_loss(model, tokenizer, input_ids, attention_mask, labels):
     with torch.no_grad():
         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
         lm_logits = outputs[1] #[bsize, seqlen, vocab]
@@ -39,7 +34,7 @@ def get_gpt2_loss(input_ids, attention_mask, labels):
 
 MAX_LENGTH = 66
 
-def run_gpt2(sents, cuda=False, model_name=None):
+def run_gpt2(sents, model, tokenizer, cuda=False, model_name=None):
     assert isinstance(sents, list)
     _sents = [tokenizer.bos_token + s for s in sents]
     inputs = tokenizer(_sents, return_tensors="pt", padding=True)
@@ -47,7 +42,7 @@ def run_gpt2(sents, cuda=False, model_name=None):
         return None
     if cuda:
         inputs = {k: v.cuda() for k, v in inputs.items()}
-    loss = get_gpt2_loss(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids'])
+    loss = get_gpt2_loss(model, tokenizer, input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids'])
    logps = - loss.detach().cpu()
    return logps

@@ -95,7 +90,7 @@ def gpt2_critic_char_level_only(sent, verbose=1, cuda=False, fp16=True, seed='au
     return is_good, float(logps[0]), counter_example
 
 
-def gpt2_critic(sent, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=100, word_level_mode='refine'):
+def gpt2_critic(sent, model, tokenizer, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=100, word_level_mode='refine'):
     return_string = []
     if seed == 'auto':
         seed = int(hashlib.md5(sent.encode()).hexdigest(), 16) % (2**32) #Seed must be between 0 and 2**32 - 1
@@ -116,9 +111,9 @@ def gpt2_critic(sent, verbose=1, cuda=False, fp16=True, seed='auto', n_samples=1
     sents = [orig_sent] + list(sent_perturbations_c.union(sent_perturbations_w))
     if fp16:
         with torch.cuda.amp.autocast():
-            logps = run_gpt2(sents, cuda)
+            logps = run_gpt2(sents, model, tokenizer, cuda)
     else:
-        logps = run_gpt2(sents, cuda)
+        logps = run_gpt2(sents, model, tokenizer, cuda)
     if logps is None:
         if verbose:
             print ('Invalid input. Maybe the sentence is too long.')
@@ -147,11 +142,108 @@ def main():
     import streamlit as st
     st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
     sent = st.text_input('Enter a sentence:', value="")
+
+    ### LMs we are trying:
     if sent != '':
-        st.markdown(f"**Sentence**: {sent}")
-
-
-        st.
+        st.markdown(f"**Input Sentence**: {sent}")
+        results = {}
+
+        with st.spinner('Running with GPT-2 LM...'):
+            ## GPT-2 LM (original LM-critic)
+            model_name = 'gpt2'
+            nice_name = "GPT-2"
+            tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+            tokenizer.pad_token = tokenizer.eos_token
+            model = GPT2LMHeadModel.from_pretrained(model_name)
+            model.eval()
+            model.cpu()
+            is_good, score, counter_example, return_string_GPT2 = gpt2_critic(sent, model, tokenizer)
+            st.markdown("**Results with GPT-2 LM:**")
+            st.write('\n'.join(return_string_GPT2))
+            results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+
+        with st.spinner('Running with OPT LM...'):
+            ## OPT LM
+            model_name = "facebook/opt-350m"
+            nice_name = "OPT"
+            model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+            tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+            tokenizer.pad_token = tokenizer.eos_token
+            model.eval()
+            model.cpu()
+            is_good, score, counter_example, return_string_OPT = gpt2_critic(sent, model, tokenizer)
+            st.markdown("**Results with OPT LM:**")
+            st.write('\n'.join(return_string_OPT))
+            results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+
+        with st.spinner('Running with GPT NEO LM...'):
+            ## GPT NEO
+            model_name = "EleutherAI/gpt-neo-1.3B"
+            nice_name = "GPT NEO"
+            model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
+            tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
+            tokenizer.pad_token = tokenizer.eos_token
+            model.eval()
+            model.cpu()
+            is_good, score, counter_example, return_string_GPTNEO = gpt2_critic(sent, model, tokenizer)
+            st.markdown("**Results with GPT NEO LM:**")
+            st.write('\n'.join(return_string_GPTNEO))
+            results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+
+        with st.spinner('Running with RoBERTa LM...'):
+            ## RoBERTa
+            model_name = "roberta-base"
+            nice_name = "RoBERTa"
+            tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+            config = RobertaConfig.from_pretrained("roberta-base")
+            config.is_decoder = True
+            model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)
+            tokenizer.pad_token = tokenizer.eos_token
+            model.eval()
+            model.cpu()
+            is_good, score, counter_example, return_string_RoBERTa = gpt2_critic(sent, model, tokenizer)
+            st.markdown("**Results with RoBERTa LM:**")
+            st.write('\n'.join(return_string_RoBERTa))
+            results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+
+        with st.spinner('Running with BART LM...'):
+            ## BART
+            model_name = "facebook/bart-base"
+            nice_name = "BART"
+            tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+            model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
+            assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+            tokenizer.pad_token = tokenizer.eos_token
+            model.eval()
+            model.cpu()
+            is_good, score, counter_example, return_string_BART = gpt2_critic(sent, model, tokenizer)
+            st.markdown("**Results with BART LM:**")
+            st.write('\n'.join(return_string_BART))
+            results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+
+        with st.spinner('Running with XLM RoBERTa LM...'):
+            ## XLM RoBERTa
+            model_name = 'xlm-roberta-base'
+            nice_name = 'XLM RoBERTa'
+            tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
+            config = XLMRobertaConfig.from_pretrained("xlm-roberta-base")
+            config.is_decoder = True
+            model = XLMRobertaForCausalLM.from_pretrained("xlm-roberta-base", config=config)
+            tokenizer.pad_token = tokenizer.eos_token
+            model.eval()
+            model.cpu()
+            is_good, score, counter_example, return_string_XLMRoBERTa = gpt2_critic(sent, model, tokenizer)
+            st.markdown("**Results with XLM RoBERTa LM:**")
+            st.write('\n'.join(return_string_XLMRoBERTa))
+            results[nice_name] = ["Good" if is_good else "Bad", str(round(score, 3)), "N/A" if not counter_example else str(counter_example[0]), "N/A" if not counter_example else str(round(counter_example[1], 3))]
+
+        df = pd.DataFrame.from_dict(results,
+                                    orient = 'index',
+                                    columns=['Judgement', 'Score (log(p))', 'Neighbor sentence with highest score (log(p))', 'Neighbor sentence score (log(p))'])
+        st.markdown("**Tabular summary of results:**")
+        st.table(df)
+
+        st.write("Done.")
 
 if __name__ == '__main__':
     main()
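Note: the refactor above threads an explicit `model` and `tokenizer` through `run_gpt2` and `gpt2_critic` in place of the old module-level globals, so the same critic loop can score a sentence and its perturbations under any causal LM. A minimal sketch of the new calling convention outside the Streamlit app (the `critic.critic` import path and the four-value return are assumptions inferred from how the demo code above unpacks the result):

# Minimal sketch (not part of the commit): calling the refactored critic
# directly, mirroring the GPT-2 branch of the Streamlit demo above.
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from critic.critic import gpt2_critic  # assumed import path within this repo

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token by default
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
model.cpu()

sent = "He like to eat apples."
# Assumed 4-tuple return (judgement, score, counter-example, log lines),
# matching the unpacking in main() above.
is_good, score, counter_example, log_lines = gpt2_critic(sent, model, tokenizer)
print("Judgement:", "Good" if is_good else "Bad", "| log(p):", round(score, 3))
if counter_example:
    print("Higher-scoring neighbor:", counter_example[0], "| log(p):", round(counter_example[1], 3))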
requirements.txt
CHANGED
@@ -1,9 +1,12 @@
-datasets==
+datasets==2.2.2
 editdistance==0.6.0
 nltk==3.7
 numpy==1.22.3
 spacy==3.0.5
 streamlit==1.9.0
 torch==1.11.0
-tqdm==4.
-transformers==4.
+tqdm==4.62.1
+transformers==4.19.2
+protobuf~=3.19.0
+sentencepiece==0.1.96
+huggingface-hub==0.1.0
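Note: `sentencepiece` and `protobuf` are presumably needed by the newly added XLMRobertaTokenizer. A quick environment sanity check against these pins (a sketch, not part of the commit; `protobuf` is skipped because it uses a compatible-release `~=` specifier rather than an exact pin):

# Sketch: verify the installed packages match the exact pins in requirements.txt.
from importlib.metadata import version, PackageNotFoundError

PINS = {
    "datasets": "2.2.2", "editdistance": "0.6.0", "nltk": "3.7",
    "numpy": "1.22.3", "spacy": "3.0.5", "streamlit": "1.9.0",
    "torch": "1.11.0", "tqdm": "4.62.1", "transformers": "4.19.2",
    "sentencepiece": "0.1.96", "huggingface-hub": "0.1.0",
}

for name, pinned in PINS.items():
    try:
        installed = version(name)
        status = "OK" if installed == pinned else f"MISMATCH (installed {installed})"
    except PackageNotFoundError:
        status = "MISSING"
    print(f"{name}=={pinned}: {status}")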