# Spaces:
# Runtime error
# Runtime error
"""
Word-level perturbation generator.
Originally by https://github.com/awasthiabhijeet/PIE/tree/master/errorify
"""
import os | |
import math | |
import pickle | |
import random | |
import editdistance | |
from numpy.random import choice as npchoice | |
from collections import defaultdict | |
# Directory that holds this module's pickled data files.
try:
    dir_path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    # __file__ is not defined (e.g. the code is exec'd or pasted into an
    # interactive session): fall back to the current working directory.
    dir_path = '.'
def _load_pickle(name):
    """Unpickle the file ``name`` located in ``dir_path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the object has been loaded.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; these data files must come from a trusted source.
    with open(f'{dir_path}/{name}', 'rb') as handle:
        return pickle.load(handle)


VERBS = _load_pickle('verbs.p')
COMMON_INSERTS = set(_load_pickle('common_inserts.p'))  # common inserts *to fix a sent*
COMMON_DELETES = _load_pickle('common_deletes.p')  # common deletes *to fix a sent*
_COMMON_REPLACES = _load_pickle('common_replaces.p')  # common replacements *to errorify a sent*
def _is_close_pair(src, tgt):
    """Return True when ``src`` and ``tgt`` pass the closeness filter.

    A pair passes when the distance reported by ``editdistance.eval`` is at
    most 2 and strictly less than half the length of the longer string.
    """
    distance = editdistance.eval(tgt, src)
    if distance > 2:
        return False
    longer = max(len(src), len(tgt))
    if longer == 0:
        # Two empty strings would divide by zero below; reject the pair.
        return False
    return float(distance) / longer < 0.5


# _COMMON_REPLACES re-keyed by its inner key (COMMON_REPLACES[tgt][src]),
# keeping only close pairs and dropping the "'re" <-> "are" pair.
COMMON_REPLACES = {}
for src in _COMMON_REPLACES:
    for tgt in _COMMON_REPLACES[src]:
        if {src, tgt} == {"'re", "are"}:
            continue
        if not _is_close_pair(src, tgt):
            continue
        COMMON_REPLACES.setdefault(tgt, {})[src] = _COMMON_REPLACES[src][tgt]

# VERBS restricted to close pairs; a key exists only if at least one of its
# forms passed the filter.
VERBS_refine = defaultdict(list)
for src in VERBS:
    for tgt in VERBS[src]:
        if _is_close_pair(src, tgt):
            VERBS_refine[src].append(tgt)
class WordLevelPerturber_all:
    """Apply one random word-level edit to a whitespace-tokenized sentence.

    ``perturb`` picks one of four operations (insert a word, swap a verb
    form, replace a word, delete a word), applies it and returns the edited
    sentence. The instance is reset to the original sentence afterwards, so
    each call starts from the same input.

    The operations read the module-level tables ``COMMON_DELETES``,
    ``VERBS``, ``COMMON_REPLACES`` and ``COMMON_INSERTS``.
    """

    def __init__(self, sentence: str):
        """Store ``sentence`` (trailing whitespace stripped) and tokenize it."""
        self.original_sentence = sentence.rstrip()
        self.sentence = self.original_sentence
        self.tokenized = None
        self.tokenize()

    def tokenize(self):
        """Split the current sentence on whitespace into ``self.tokenized``."""
        self.tokenized = self.sentence.split()

    def orig(self):
        """Return the original (right-stripped) sentence."""
        return self.original_sentence

    @staticmethod
    def _weighted_pick(weights):
        """Draw one key of ``weights`` with probability proportional to its value.

        ``weights`` must be a non-empty mapping whose values sum to a
        non-zero number.
        """
        keys = list(weights)
        total = sum(weights.values())
        probs = [value / total for value in weights.values()]
        # Sample an index rather than the keys themselves so the returned
        # object is the original key, not a NumPy scalar.
        return keys[npchoice(len(keys), p=probs)]

    def _insert(self):
        """Insert a commonly deleted word."""
        if self.tokenized and COMMON_DELETES:
            # The position is drawn from 0..len-1, so the new word is never
            # placed after the last token.
            index = random.randrange(len(self.tokenized))
            self.tokenized.insert(index, self._weighted_pick(COMMON_DELETES))
        return ' '.join(self.tokenized)

    def _mod_verb(self, redir=True):
        """Replace one token found in ``VERBS`` with one of its listed forms.

        When no token is in ``VERBS`` and ``redir`` is true, fall back to
        ``_replace`` (with redirection disabled to avoid bouncing back).
        Returns the unchanged sentence when nothing can be modified.
        """
        if self.tokenized:
            verbs = [i for i, w in enumerate(self.tokenized) if w in VERBS]
            if not verbs:
                if redir:
                    return self._replace(redir=False)
                return self.sentence
            index = random.choice(verbs)
            forms = VERBS[self.tokenized[index]]
            if not forms:
                return self.sentence
            self.tokenized[index] = random.choice(forms)
        return ' '.join(self.tokenized)

    def _delete(self):
        """Delete a commonly inserted word."""
        # A sentence of zero or one token is never shortened.
        if len(self.tokenized) > 1:
            deletable = [i for i, w in enumerate(self.tokenized)
                         if w in COMMON_INSERTS]
            if not deletable:
                return self.sentence
            del self.tokenized[random.choice(deletable)]
        return ' '.join(self.tokenized)

    def _replace(self, redir=True):
        """Replace one token found in ``COMMON_REPLACES`` with a weighted pick.

        When no token is in ``COMMON_REPLACES`` and ``redir`` is true, fall
        back to ``_mod_verb`` (with redirection disabled). Returns the
        unchanged sentence when nothing can be replaced.
        """
        if self.tokenized:
            replaceable = [i for i, w in enumerate(self.tokenized)
                           if w in COMMON_REPLACES]
            if not replaceable:
                if redir:
                    return self._mod_verb(redir=False)
                return self.sentence
            index = random.choice(replaceable)
            candidates = COMMON_REPLACES[self.tokenized[index]]
            if not candidates:
                return self.sentence
            self.tokenized[index] = self._weighted_pick(candidates)
        return ' '.join(self.tokenized)

    def perturb(self):
        """Apply one randomly chosen operation and return the edited sentence.

        The instance is restored to the original sentence before returning.
        """
        count = 1
        operations = [self._insert, self._mod_verb, self._replace, self._delete]
        perturb_probs = [.30, .30, .30, .10]
        for _ in range(count):
            # Sample an index so NumPy never has to wrap the bound methods.
            chosen = operations[npchoice(len(operations), p=perturb_probs)]
            self.sentence = chosen()
            self.tokenize()
        res_sentence = self.sentence
        self.sentence = self.original_sentence
        self.tokenize()
        return res_sentence
class WordLevelPerturber_refine:
    """Apply one random, restricted word-level edit to a sentence.

    ``perturb`` picks one of four operations (insert a word, swap a verb
    form, replace a word, delete a word), applies it and returns the edited
    sentence. The instance is reset to the original sentence afterwards.

    Restrictions applied by this class:

    * verb swaps use the filtered table ``VERBS_refine``;
    * deletion only removes a ``COMMON_INSERTS`` word that repeats the
      previous token (case-insensitively);
    * replacement never touches "not" or "n't" (case-insensitively).

    The operations also read the module-level tables ``COMMON_DELETES`` and
    ``COMMON_REPLACES``.
    """

    # Tokens (compared lower-cased) that _replace must leave untouched.
    _PROTECTED_WORDS = frozenset({"not", "n't"})

    def __init__(self, sentence: str):
        """Store ``sentence`` (trailing whitespace stripped) and tokenize it."""
        self.original_sentence = sentence.rstrip()
        self.sentence = self.original_sentence
        self.tokenized = None
        self.tokenize()

    def tokenize(self):
        """Split the current sentence on whitespace into ``self.tokenized``."""
        self.tokenized = self.sentence.split()

    def orig(self):
        """Return the original (right-stripped) sentence."""
        return self.original_sentence

    @staticmethod
    def _weighted_pick(weights):
        """Draw one key of ``weights`` with probability proportional to its value.

        ``weights`` must be a non-empty mapping whose values sum to a
        non-zero number.
        """
        keys = list(weights)
        total = sum(weights.values())
        probs = [value / total for value in weights.values()]
        # Sample an index rather than the keys themselves so the returned
        # object is the original key, not a NumPy scalar.
        return keys[npchoice(len(keys), p=probs)]

    def _insert(self):
        """Insert a commonly deleted word."""
        if self.tokenized and COMMON_DELETES:
            # The position is drawn from 0..len-1, so the new word is never
            # placed after the last token.
            index = random.randrange(len(self.tokenized))
            self.tokenized.insert(index, self._weighted_pick(COMMON_DELETES))
        return ' '.join(self.tokenized)

    def _mod_verb(self, redir=True):
        """Replace one token found in ``VERBS_refine`` with one of its forms.

        When no token is in ``VERBS_refine`` and ``redir`` is true, fall back
        to ``_replace`` (with redirection disabled to avoid bouncing back).
        Returns the unchanged sentence when nothing can be modified.
        """
        if self.tokenized:
            verbs = [i for i, w in enumerate(self.tokenized) if w in VERBS_refine]
            if not verbs:
                if redir:
                    return self._replace(redir=False)
                return self.sentence
            index = random.choice(verbs)
            forms = VERBS_refine[self.tokenized[index]]
            if not forms:
                return self.sentence
            self.tokenized[index] = random.choice(forms)
        return ' '.join(self.tokenized)

    def _delete(self):
        """Delete a commonly inserted word that repeats the previous token."""
        # A sentence of zero or one token is never shortened.
        if len(self.tokenized) > 1:
            toks = self.tokenized
            deletable = [
                i for i, w in enumerate(toks)
                if i > 0
                and toks[i - 1].lower() == w.lower()
                and w in COMMON_INSERTS
            ]
            if not deletable:
                return self.sentence
            del self.tokenized[random.choice(deletable)]
        return ' '.join(self.tokenized)

    def _replace(self, redir=True):
        """Replace one eligible token with a weighted pick from ``COMMON_REPLACES``.

        A token is eligible when it is a key of ``COMMON_REPLACES`` and is not
        one of the protected words. When no token is eligible and ``redir``
        is true, fall back to ``_mod_verb`` (with redirection disabled).
        Returns the unchanged sentence when nothing can be replaced.
        """
        if self.tokenized:
            replaceable = [
                i for i, w in enumerate(self.tokenized)
                if w in COMMON_REPLACES
                and w.lower() not in self._PROTECTED_WORDS
            ]
            if not replaceable:
                if redir:
                    return self._mod_verb(redir=False)
                return self.sentence
            index = random.choice(replaceable)
            candidates = COMMON_REPLACES[self.tokenized[index]]
            if not candidates:
                return self.sentence
            self.tokenized[index] = self._weighted_pick(candidates)
        return ' '.join(self.tokenized)

    def perturb(self):
        """Apply one randomly chosen operation and return the edited sentence.

        The instance is restored to the original sentence before returning.
        """
        count = 1
        operations = [self._insert, self._mod_verb, self._replace, self._delete]
        perturb_probs = [.30, .30, .30, .10]
        for _ in range(count):
            # Sample an index so NumPy never has to wrap the bound methods.
            chosen = operations[npchoice(len(operations), p=perturb_probs)]
            self.sentence = chosen()
            self.tokenize()
        res_sentence = self.sentence
        self.sentence = self.original_sentence
        self.tokenize()
        return res_sentence