Bangla-PoS-Taggers / helper /
musfiqdehan's picture
Syncing huggingface space and github
history blame
No virus
3.21 kB
"""This module contains the helper functions to get the word alignment mapping between two sentences."""
import torch
import itertools
import transformers
from transformers import logging
# Set the verbosity to error, so that the warning messages are not printed
# Loaded (model, tokenizer) pairs keyed by model path, so repeated calls do
# not reload the weights from disk / the hub every time.
_ALIGNER_CACHE = {}


def _load_aligner(model_path):
    """Return the ``(model, tokenizer)`` pair for *model_path*, loading it once."""
    if model_path not in _ALIGNER_CACHE:
        _ALIGNER_CACHE[model_path] = (
            transformers.BertModel.from_pretrained(model_path),
            transformers.BertTokenizer.from_pretrained(model_path),
        )
    return _ALIGNER_CACHE[model_path]


def _encode_words(tokenizer, words):
    """Return ``(input_ids, sub2word)`` for a list of whitespace-split words.

    ``sub2word[k]`` is the index in *words* of the word that produced the
    k-th sub-word token.
    """
    subwords = [tokenizer.tokenize(word) for word in words]
    flat_ids = list(itertools.chain.from_iterable(
        tokenizer.convert_tokens_to_ids(pieces) for pieces in subwords))
    # NOTE(review): `model_max_length` is forwarded as an extra keyword exactly
    # as before; the documented length parameter of `prepare_for_model` is
    # `max_length` -- confirm this keyword is actually honoured.
    input_ids = tokenizer.prepare_for_model(
        flat_ids,
        return_tensors='pt',
        truncation=True,
        model_max_length=tokenizer.model_max_length,
    )['input_ids']
    sub2word = [word_idx for word_idx, pieces in enumerate(subwords) for _ in pieces]
    return input_ids, sub2word


def get_alignment_mapping(source="", target="", model_path="musfiqdehan/bn-en-word-aligner",
                          *, align_layer=8, threshold=1e-3):
    """Get Aligned Words.

    Splits both sentences on whitespace, embeds their sub-word tokens with the
    BERT model found at ``model_path`` and aligns sub-words whose similarity
    passes ``threshold`` in both directions.

    Args:
        source: Source sentence; words are separated by whitespace.
        target: Target sentence; words are separated by whitespace.
        model_path: Name or path passed to ``from_pretrained`` for both the
            model and the tokenizer. The loaded pair is cached per path for
            the lifetime of the process.
        align_layer: Index into the model's hidden-states tuple whose
            representations are compared.
        threshold: Minimum softmax probability a sub-word pair must exceed,
            both source->target and target->source, to count as aligned.

    Returns:
        A tuple ``(sent_src, sent_tgt, align_words)`` where ``sent_src`` and
        ``sent_tgt`` are the word lists and ``align_words`` is a set of
        ``(source_word_index, target_word_index)`` tuples.
    """
    sent_src, sent_tgt = source.strip().split(), target.strip().split()
    if not sent_src or not sent_tgt:
        # Nothing can be aligned with an empty sentence (this is also the
        # default-argument case), so skip loading and running the model.
        return sent_src, sent_tgt, set()

    model, tokenizer = _load_aligner(model_path)

    # pre-processing
    ids_src, sub2word_map_src = _encode_words(tokenizer, sent_src)
    ids_tgt, sub2word_map_tgt = _encode_words(tokenizer, sent_tgt)

    # alignment
    with torch.no_grad():
        # Output element 2 holds the per-layer hidden states; the slice 1:-1
        # drops the first and last positions (presumably the special tokens
        # added by prepare_for_model).
        out_src = model(ids_src.unsqueeze(0), output_hidden_states=True)[
            2][align_layer][0, 1:-1]
        out_tgt = model(ids_tgt.unsqueeze(0), output_hidden_states=True)[
            2][align_layer][0, 1:-1]

        dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))

        softmax_srctgt = torch.softmax(dot_prod, dim=-1)
        softmax_tgtsrc = torch.softmax(dot_prod, dim=-2)

        # Keep only pairs that pass the threshold in both directions.
        softmax_inter = (softmax_srctgt > threshold) & (softmax_tgtsrc > threshold)

    # .tolist() yields plain Python ints, so the word maps are indexed with
    # ints rather than 0-d tensors.
    align_subwords = torch.nonzero(softmax_inter, as_tuple=False).tolist()
    align_words = {
        (sub2word_map_src[i], sub2word_map_tgt[j]) for i, j in align_subwords
    }

    return sent_src, sent_tgt, align_words
def get_word_mapping(source="", target="", model_path="musfiqdehan/bn-en-word-aligner"):
    """Get Word Aligned Mapping Words.

    Runs ``get_alignment_mapping`` on the two sentences and renders every
    aligned pair, in sorted index order, as ``bn:(<source word>) -> en:(<target word>)``.

    Returns:
        A list of formatted strings, one per aligned word pair.
    """
    src_words, tgt_words, pairs = get_alignment_mapping(
        source=source, target=target, model_path=model_path)

    return [
        f'bn:({src_words[src_idx]}) -> en:({tgt_words[tgt_idx]})'
        for src_idx, tgt_idx in sorted(pairs)
    ]
def get_word_index_mapping(source="", target="", model_path="musfiqdehan/bn-en-word-aligner"):
    """Get Word Aligned Mapping Index.

    Runs ``get_alignment_mapping`` on the two sentences and renders every
    aligned pair, in sorted index order, as ``bn:(<source index>) -> en:(<target index>)``.

    Returns:
        A list of formatted strings, one per aligned word-index pair.
    """
    # Only the index pairs are needed here; the word lists are discarded.
    _, _, pairs = get_alignment_mapping(
        source=source, target=target, model_path=model_path)

    return [f'bn:({src_idx}) -> en:({tgt_idx})' for src_idx, tgt_idx in sorted(pairs)]