from typing import Callable, List, Optional

import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from .modeling import GECToR

def load_verb_dict(verb_file: str):
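    """Load the verb-form dictionary used to resolve $TRANSFORM_VERB_ labels.

    Each line has the form "word1_word2:tag1_tag2". ``encode`` maps
    "word1_word2" -> "tag1_tag2"; ``decode`` maps "word1_tag1_tag2" -> "word2".
    """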
    encode, decode = {}, {}
    with open(verb_file, encoding="utf-8") as f:
        for line in f:
            words, tags = line.split(":")
            word1, word2 = words.split("_")
            tag1, tag2 = tags.split("_")
            decode_key = f"{word1}_{tag1}_{tag2.strip()}"
            if decode_key not in decode:
                encode[words] = tags
                decode[decode_key] = word2
    return encode, decode

def edit_src_by_tags(
    srcs: List[List[str]],
    pred_labels: List[List[str]],
    encode: dict,
    decode: dict
) -> List[List[str]]:
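    """Apply predicted word-level labels to each tokenized source sentence
    and return the edited token sequences."""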
    edited_srcs = []
    for tokens, labels in zip(srcs, pred_labels):
        edited_tokens = []
        for t, l in zip(tokens, labels):
            n_token = process_token(t, l, encode, decode)
            if n_token is None:
                n_token = t
            edited_tokens += n_token.split(' ')
        if len(tokens) > len(labels):
            # Tokens beyond max_length receive no labels; keep them unchanged.
            omitted_tokens = tokens[len(labels):]
            edited_tokens += omitted_tokens
        # Resolve residual merge/delete markers left by process_token().
        temp_str = ' '.join(edited_tokens) \
            .replace(' $MERGE_HYPHEN ', '-') \
            .replace(' $MERGE_SPACE ', '') \
            .replace(' $DELETE', '') \
            .replace('$DELETE ', '')
        edited_srcs.append(temp_str.split(' '))
    return edited_srcs

def process_token(
    token: str,
    label: str,
    encode: dict,
    decode: dict
) -> Optional[str]:
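    """Apply a single edit label to a single token.

    Returns the edited text (possibly several space-separated tokens, and
    possibly containing a $DELETE or $MERGE_* marker that edit_src_by_tags()
    resolves later), or None when a verb transform is not in the dictionary.
    """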
    if '$APPEND_' in label:
        return token + ' ' + label.replace('$APPEND_', '')
    elif token == '$START':
        # The $START ([unused1]) token cannot be replaced with another token
        # and cannot be deleted.
        return token
    elif label in ['<PAD>', '<OOV>', '$KEEP']:
        return token
    elif '$TRANSFORM_' in label:
        return g_transform_processer(token, label, encode, decode)
    elif '$REPLACE_' in label:
        return label.replace('$REPLACE_', '')
    elif label == '$DELETE':
        return label
    elif '$MERGE_' in label:
        return token + ' ' + label
    else:
        return token

def g_transform_processer(
    token: str,
    label: str,
    encode: dict,
    decode: dict
) -> Optional[str]:
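    """Handle $TRANSFORM_* labels: case changes, singular/plural agreement,
    hyphen splitting, and dictionary-based verb-form changes."""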
    # Case-related transforms
    if label == '$TRANSFORM_CASE_LOWER':
        return token.lower()
    elif label == '$TRANSFORM_CASE_UPPER':
        return token.upper()
    elif label == '$TRANSFORM_CASE_CAPITAL':
        return token.capitalize()
    elif label == '$TRANSFORM_CASE_CAPITAL_1':
        # Capitalize the second character, e.g. 'iphone' -> 'iPhone'.
        if len(token) <= 1:
            return token
        return token[0] + token[1:].capitalize()
    elif label == '$TRANSFORM_AGREEMENT_PLURAL':
        return token + 's'
    elif label == '$TRANSFORM_AGREEMENT_SINGULAR':
        return token[:-1]
    elif label == '$TRANSFORM_SPLIT_HYPHEN':
        return ' '.join(token.split('-'))
    else:
        # Verb-form transform: look up "<token>_<tag1>_<tag2>" in the decode
        # dict. Returns None if the form is missing; the caller then falls
        # back to the original token.
        encoding_part = f"{token}_{label[len('$TRANSFORM_VERB_'):]}"
        decoded_target_word = decode.get(encoding_part)
        return decoded_target_word

def get_word_masks_from_word_ids(
    word_ids: Callable[[int], List[Optional[int]]],
    n: int
):
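    """Build 0/1 masks that mark the first subword of each word.

    ``word_ids`` is the ``BatchEncoding.word_ids`` method of a fast
    tokenizer: called with a batch index, it yields one word id per
    subword, or None for special tokens.
    """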
    word_masks = []
    for i in range(n):
        previous_id = 0
        mask = []
        for _id in word_ids(i):
            if _id is None:
                # Special tokens ([CLS], [SEP], padding) are masked out.
                mask.append(0)
            elif previous_id != _id:
                # First subword of a new word.
                mask.append(1)
            else:
                # Continuation subword of the same word.
                mask.append(0)
            previous_id = _id
        word_masks.append(mask)
    return word_masks

def _predict(
    model: GECToR,
    tokenizer: PreTrainedTokenizer,
    srcs: List[List[str]],
    keep_confidence: float = 0,
    min_error_prob: float = 0,
    batch_size: int = 128
):
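    """Predict word-level edit labels for tokenized sentences in mini-batches.

    Returns (pred_labels, no_corrections): one label sequence per sentence,
    plus a flag per sentence that is True when every predicted label is a
    non-correction (<PAD>, <OOV> or $KEEP).
    """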
    itr = list(range(0, len(srcs), batch_size))
    pred_labels = []
    no_corrections = []
    for i in tqdm(itr):
        batch = tokenizer(
            srcs[i:i + batch_size],
            return_tensors='pt',
            max_length=model.config.max_length,
            padding='max_length',
            truncation=True,
            is_split_into_words=True
        )
        batch['word_masks'] = torch.tensor(
            get_word_masks_from_word_ids(
                batch.word_ids,
                batch['input_ids'].size(0)
            )
        )
        # Keep a reference to the word_ids() method before batch is
        # converted to a plain dict below.
        word_ids = batch.word_ids
        if torch.cuda.is_available():
            batch = {k: v.cuda() for k, v in batch.items()}
        outputs = model.predict(
            batch['input_ids'],
            batch['attention_mask'],
            batch['word_masks'],
            keep_confidence,
            min_error_prob
        )
        # Align subword-level labels to word-level labels.
        for j in range(len(outputs.pred_labels)):
            no_correct = True
            labels = []
            previous_word_idx = None
            for k, idx in enumerate(word_ids(j)):
                if idx is None:
                    continue
                if idx != previous_word_idx:
                    labels.append(outputs.pred_labels[j][k])
                    if outputs.pred_label_ids[j][k] > 2:
                        # Ids 0-2 are the non-correction labels
                        # (<PAD>, <OOV>, $KEEP).
                        no_correct = False
                previous_word_idx = idx
            pred_labels.append(labels)
            no_corrections.append(no_correct)
    return pred_labels, no_corrections

def predict(
    model: GECToR,
    tokenizer: PreTrainedTokenizer,
    srcs: List[str],
    encode: dict,
    decode: dict,
    keep_confidence: float = 0,
    min_error_prob: float = 0,
    batch_size: int = 128,
    n_iteration: int = 5
) -> List[str]:
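    """Correct raw sentences by iteratively predicting and applying edit
    labels, for at most ``n_iteration`` rounds.

    A sentence leaves the loop as soon as the model predicts no further
    corrections for it; anything still unfinished after the last round is
    returned in its current state.
    """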
    srcs = [['$START'] + src.split(' ') for src in srcs]
    final_edited_sents = ['-1'] * len(srcs)
    to_be_processed = srcs
    original_sent_idx = list(range(0, len(srcs)))
    for itr in range(n_iteration):
        print(f'Iteration {itr}: {len(to_be_processed)} sentences to be processed.')
        pred_labels, no_corrections = _predict(
            model,
            tokenizer,
            to_be_processed,
            keep_confidence,
            min_error_prob,
            batch_size
        )
        current_srcs = []
        current_pred_labels = []
        current_orig_idx = []
        for i, no_correction in enumerate(no_corrections):
            if no_correction:
                # No further corrections were predicted, so the sentence is final.
                final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
            else:
                current_srcs.append(to_be_processed[i])
                current_pred_labels.append(pred_labels[i])
                current_orig_idx.append(original_sent_idx[i])
        if current_srcs == []:
            # Correction is completed for all sentences.
            break
        edited_srcs = edit_src_by_tags(
            current_srcs,
            current_pred_labels,
            encode,
            decode
        )
        to_be_processed = edited_srcs
        original_sent_idx = current_orig_idx
    # Sentences still unfinished after the last iteration are kept as-is.
    for i in range(len(to_be_processed)):
        final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
    assert '-1' not in final_edited_sents
    return final_edited_sents
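

if __name__ == '__main__':
    # Minimal usage sketch, not part of the module's API. Run it as a module
    # (e.g. `python -m <package>.predict`) so the relative import of GECToR
    # resolves. The checkpoint id and dictionary path below are assumptions;
    # substitute your own trained GECToR checkpoint (assumed to support
    # `from_pretrained`) and verb-form dictionary.
    from transformers import AutoTokenizer

    model_id = 'gotutiyan/gector-roberta-base-5k'  # assumed checkpoint id
    model = GECToR.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    encode, decode = load_verb_dict('data/verb-form-vocab.txt')  # assumed path
    srcs = ['This are a wrong sentences .']
    print(predict(model, tokenizer, srcs, encode, decode))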