"""Wrapper of AllenNLP model. Fixes errors based on model predictions"""
import logging
import os
import sys
from time import time

import torch
from allennlp.data.dataset import Batch
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.nn import util

from gector.bert_token_embedder import PretrainedBertEmbedder
from gector.seq2labels_model import Seq2Labels
from gector.tokenizer_indexer import PretrainedBertIndexer
from utils.helpers import PAD, UNK, get_target_sent_by_edits, START_TOKEN
from utils.helpers import get_weights_name

logging.getLogger("werkzeug").setLevel(logging.ERROR)
logger = logging.getLogger(__file__)


class GecBERTModel(object):
    """Weighted ensemble of ``Seq2Labels`` taggers that iteratively edits token batches.

    Each checkpoint in ``model_paths`` gets its own BERT-style indexer and
    ``Seq2Labels`` model. Predictions are averaged with ``model_weights``,
    turned into token edits, and reapplied for up to ``iterations`` rounds.
    """

    def __init__(self,
                 vocab_path=None,
                 model_paths=None,
                 weigths=None,
                 max_len=50,
                 min_len=3,
                 lowercase_tokens=False,
                 log=False,
                 iterations=3,
                 model_name='roberta',
                 special_tokens_fix=1,
                 is_ensemble=True,
                 min_error_probability=0.0,
                 confidence=0,
                 del_confidence=0,
                 resolve_cycles=False,
                 ):
        """Load the vocabulary and every checkpoint in ``model_paths``.

        Args:
            vocab_path: directory passed to ``Vocabulary.from_files``.
            model_paths: iterable of checkpoint paths (required; ``None`` fails).
            weigths: optional per-model weights (parsed with ``float``); when
                falsy every model gets weight 1. The misspelled name is kept
                for backward compatibility with keyword callers.
            max_len: sequences are truncated to this many tokens.
            min_len: sentences shorter than this are never sent to the models.
            lowercase_tokens: forwarded to weight-name lookup and the indexer.
            log: print per-iteration / inference timing information.
            iterations: maximum number of correction rounds in ``handle_batch``.
            model_name, special_tokens_fix: used for every model unless
                ``is_ensemble`` is true, in which case both are parsed from each
                checkpoint's file name (``<model>_<stf>_...``).
            min_error_probability: predictions below this are ignored.
            confidence, del_confidence: forwarded to ``Seq2Labels``.
            resolve_cycles: stored on the instance; not used in this file.
        """
        self.model_weights = list(map(float, weigths)) if weigths else [1] * len(model_paths)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = max_len
        self.min_len = min_len
        self.lowercase_tokens = lowercase_tokens
        self.min_error_probability = min_error_probability
        self.vocab = Vocabulary.from_files(vocab_path)
        self.log = log
        self.iterations = iterations
        self.confidence = confidence
        self.del_conf = del_confidence
        self.resolve_cycles = resolve_cycles
        # set training parameters and operations

        self.indexers = []
        self.models = []
        for model_path in model_paths:
            if is_ensemble:
                # Each checkpoint encodes its own backbone / special-token setting.
                model_name, special_tokens_fix = self._get_model_data(model_path)
            weights_name = get_weights_name(model_name, lowercase_tokens)
            self.indexers.append(self._get_indexer(weights_name, special_tokens_fix))
            model = Seq2Labels(vocab=self.vocab,
                               text_field_embedder=self._get_embbeder(weights_name, special_tokens_fix),
                               confidence=self.confidence,
                               del_confidence=self.del_conf,
                               ).to(self.device)
            if torch.cuda.is_available():
                model.load_state_dict(torch.load(model_path), strict=False)
            else:
                model.load_state_dict(torch.load(model_path,
                                                 map_location=torch.device('cpu')),
                                      strict=False)
            model.eval()
            self.models.append(model)

    @staticmethod
    def _get_model_data(model_path):
        """Parse ``(model_name, special_tokens_fix)`` from a checkpoint file name.

        The file name (last ``/``-separated component) is expected to look like
        ``<model>_<stf>_...``; the second field is converted with ``int``.
        """
        model_name = model_path.split('/')[-1]
        tr_model, stf = model_name.split('_')[:2]
        return tr_model, int(stf)

    def _restore_model(self, input_path):
        """Copy matching tensors from the checkpoint at ``input_path`` into ``self.model``.

        Directories are rejected. Parameters whose name is missing from the
        target state dict, or whose shape does not match (``RuntimeError``),
        are skipped.

        NOTE(review): ``__init__`` never sets ``self.model`` (it builds the
        ``self.models`` list), so this helper raises ``AttributeError`` unless a
        caller assigns ``self.model`` first. It is not called in this file.
        """
        if os.path.isdir(input_path):
            print("Model could not be restored from directory", file=sys.stderr)
            filenames = []
        else:
            filenames = [input_path]
        for model_path in filenames:
            try:
                if torch.cuda.is_available():
                    loaded_model = torch.load(model_path)
                else:
                    loaded_model = torch.load(model_path,
                                              map_location=lambda storage, loc: storage)
            except Exception:
                # The old bare ``except:`` fell through and then crashed on the
                # unbound ``loaded_model``; skip the unreadable file instead.
                print(f"{model_path} is not valid model", file=sys.stderr)
                continue
            own_state = self.model.state_dict()
            for name, weights in loaded_model.items():
                if name not in own_state:
                    continue
                try:
                    # ``filenames`` holds at most one path, so this always copies;
                    # the accumulate branch is kept for multi-file restoring.
                    if len(filenames) == 1:
                        own_state[name].copy_(weights)
                    else:
                        own_state[name] += weights
                except RuntimeError:
                    continue
        print("Model is restored", file=sys.stderr)

    def predict(self, batches):
        """Run each model on its own indexed batch and merge the outputs.

        ``batches`` is paired positionally with ``self.models`` (one batch per
        model, as produced by ``preprocess``).

        Returns:
            ``(probabilities, label_indices, error_probabilities)`` as Python
            lists, see ``_convert``.
        """
        t11 = time()
        # move_to_device takes a CUDA ordinal, with -1 meaning "stay on CPU".
        cuda_device = 0 if torch.cuda.is_available() else -1
        predictions = []
        for batch, model in zip(batches, self.models):
            batch = util.move_to_device(batch.as_tensor_dict(), cuda_device)
            with torch.no_grad():
                prediction = model.forward(**batch)
            predictions.append(prediction)

        preds, idx, error_probs = self._convert(predictions)
        t55 = time()
        if self.log:
            print(f"Inference time {t55 - t11}")
        return preds, idx, error_probs

    def get_token_action(self, token, index, prob, sugg_token):
        """Get list of suggested actions for token.

        Translates one predicted label into an edit tuple
        ``(start, end, replacement, prob)``. Positions are shifted by -1 because
        ``index`` counts the prepended ``$START`` token.

        Returns:
            The edit tuple, or None when the label is a no-op
            (``$KEEP``/UNK/PAD), ``prob`` is below ``min_error_probability``,
            or the label belongs to no known edit family.
        """
        # cases when we don't need to do anything
        if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
            return None

        if sugg_token.startswith(('$REPLACE_', '$TRANSFORM_')) or sugg_token == '$DELETE':
            # In-place edits cover the current token.
            start_pos = index
            end_pos = index + 1
        elif sugg_token.startswith(("$APPEND_", "$MERGE_")):
            # Insertions / merges apply at the gap after the current token.
            start_pos = index + 1
            end_pos = index + 1
        else:
            # Unknown label family: previously an UnboundLocalError; treat as no-op.
            return None

        if sugg_token == "$DELETE":
            sugg_token_clear = ""
        elif sugg_token.startswith(('$TRANSFORM_', "$MERGE_")):
            # Transform / merge labels are passed through whole.
            sugg_token_clear = sugg_token
        else:
            # $REPLACE_x / $APPEND_x: keep only the payload after the first '_'.
            sugg_token_clear = sugg_token[sugg_token.index('_') + 1:]

        return start_pos - 1, end_pos - 1, sugg_token_clear, prob

    def _get_embbeder(self, weigths_name, special_tokens_fix):
        """Build a frozen, top-layer-only BERT text-field embedder for ``weigths_name``."""
        embedders = {'bert': PretrainedBertEmbedder(
            pretrained_model=weigths_name,
            requires_grad=False,
            top_layer_only=True,
            special_tokens_fix=special_tokens_fix)
        }
        text_field_embedder = BasicTextFieldEmbedder(
            token_embedders=embedders,
            embedder_to_indexer_map={"bert": ["bert", "bert-offsets"]},
            allow_unmatched_keys=True)
        return text_field_embedder

    def _get_indexer(self, weights_name, special_tokens_fix):
        """Build the ``{'bert': PretrainedBertIndexer}`` mapping for ``weights_name``."""
        bert_token_indexer = PretrainedBertIndexer(
            pretrained_model=weights_name,
            do_lowercase=self.lowercase_tokens,
            max_pieces_per_token=5,
            special_tokens_fix=special_tokens_fix
        )
        return {'bert': bert_token_indexer}

    def preprocess(self, token_batch):
        """Index ``token_batch`` once per model-specific indexer.

        Every sequence is truncated to ``min(longest non-empty length, max_len)``
        tokens and prefixed with ``$START``.

        Returns:
            One allennlp ``Batch`` per entry of ``self.indexers``, or ``[]`` when
            every sequence is empty.
        """
        seq_lens = [len(sequence) for sequence in token_batch if sequence]
        if not seq_lens:
            return []
        max_len = min(max(seq_lens), self.max_len)
        batches = []
        for indexer in self.indexers:
            batch = []
            for sequence in token_batch:
                tokens = sequence[:max_len]
                tokens = [Token(token) for token in ['$START'] + tokens]
                batch.append(Instance({'tokens': TextField(tokens, indexer)}))
            batch = Batch(batch)
            batch.index_instances(self.vocab)
            batches.append(batch)

        return batches

    def _convert(self, data):
        """Combine model outputs as a weight-normalised average.

        Args:
            data: list of output dicts, each holding ``class_probabilities_labels``
                and ``max_error_probability`` tensors.

        Returns:
            ``(max_probs, argmax_label_indices, error_probs)`` as Python lists,
            with the max/argmax taken over the last dimension.
        """
        all_class_probs = torch.zeros_like(data[0]['class_probabilities_labels'])
        error_probs = torch.zeros_like(data[0]['max_error_probability'])
        total_weight = sum(self.model_weights)  # loop-invariant normaliser
        for output, weight in zip(data, self.model_weights):
            all_class_probs += weight * output['class_probabilities_labels'] / total_weight
            error_probs += weight * output['max_error_probability'] / total_weight

        max_vals = torch.max(all_class_probs, dim=-1)
        probs = max_vals[0].tolist()
        idx = max_vals[1].tolist()
        return probs, idx, error_probs.tolist()

    def update_final_batch(self, final_batch, pred_ids, pred_batch,
                           prev_preds_dict):
        """Write this round's predictions back into ``final_batch``.

        A changed sentence that was never seen before stays active for the next
        round. A changed sentence that repeats an earlier variant is written but
        dropped from further rounds, which breaks edit cycles.

        Returns:
            ``(final_batch, still_active_ids, number_of_changed_sentences)``.
        """
        new_pred_ids = []
        total_updated = 0
        for i, orig_id in enumerate(pred_ids):
            orig = final_batch[orig_id]
            pred = pred_batch[i]
            prev_preds = prev_preds_dict[orig_id]
            if orig != pred and pred not in prev_preds:
                final_batch[orig_id] = pred
                new_pred_ids.append(orig_id)
                prev_preds_dict[orig_id].append(pred)
                total_updated += 1
            elif orig != pred and pred in prev_preds:
                # update final batch, but stop iterations
                final_batch[orig_id] = pred
                total_updated += 1
        return final_batch, new_pred_ids, total_updated

    def postprocess_batch(self, batch, all_probabilities, all_idxs,
                          error_probs):
        """Apply predicted labels to each sentence and return corrected token lists.

        Sentences are returned unchanged when every predicted index is 0, or
        when their sentence-level error probability is below
        ``min_error_probability``.
        """
        all_results = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")
        for tokens, probabilities, idxs, error_prob in zip(batch,
                                                           all_probabilities,
                                                           all_idxs,
                                                           error_probs):
            length = min(len(tokens), self.max_len)
            edits = []

            # skip whole sentences if there are no errors
            # (assumes label index 0 is $KEEP — TODO confirm against the vocab)
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # skip whole sentence if probability of correctness is not high
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            for i in range(length + 1):
                # because of START token
                if i == 0:
                    token = START_TOKEN
                else:
                    token = tokens[i - 1]
                # skip if there is no error
                if idxs[i] == noop_index:
                    continue

                sugg_token = self.vocab.get_token_from_index(idxs[i],
                                                             namespace='labels')
                action = self.get_token_action(token, i, probabilities[i],
                                               sugg_token)
                if not action:
                    continue

                edits.append(action)
            all_results.append(get_target_sent_by_edits(tokens, edits))
        return all_results

    def handle_batch(self, full_batch):
        """Iteratively correct a batch of tokenized sentences.

        Sentences shorter than ``min_len`` are never sent to the models.
        Iteration stops after ``self.iterations`` rounds, or earlier once no
        sentence remains active (see ``update_final_batch``).

        Returns:
            ``(corrected_batch, total_number_of_sentence_updates)``.
        """
        final_batch = full_batch[:]
        batch_size = len(full_batch)
        prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
        # Set for O(1) membership tests (was a list scanned once per sentence).
        short_ids = {i for i in range(len(full_batch))
                     if len(full_batch[i]) < self.min_len}
        pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
        total_updates = 0

        for n_iter in range(self.iterations):
            orig_batch = [final_batch[i] for i in pred_ids]

            sequences = self.preprocess(orig_batch)

            if not sequences:
                break
            probabilities, idxs, error_probs = self.predict(sequences)

            pred_batch = self.postprocess_batch(orig_batch, probabilities,
                                                idxs, error_probs)
            if self.log:
                print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.")

            final_batch, pred_ids, cnt = \
                self.update_final_batch(final_batch, pred_ids, pred_batch,
                                        prev_preds_dict)
            total_updates += cnt

            if not pred_ids:
                break

        return final_batch, total_updates