musfiqdehan commited on
Commit
a45e982
1 Parent(s): cda743d

Refactor alignment_mappers.py to support multiple models

Browse files
Files changed (1) hide show
  1. helper/alignment_mappers.py +29 -10
helper/alignment_mappers.py CHANGED
@@ -12,21 +12,37 @@ logging.set_verbosity_warning()
12
  logging.set_verbosity_error()
13
 
14
 
15
- def get_alignment_mapping(source="", target="", model_path="musfiqdehan/bn-en-word-aligner"):
 
 
 
 
 
 
 
 
 
 
 
 
16
  """
17
  Get Aligned Words
18
  """
19
- model = transformers.BertModel.from_pretrained(model_path)
20
- tokenizer = transformers.BertTokenizer.from_pretrained(model_path)
 
 
21
 
22
  # pre-processing
23
  sent_src, sent_tgt = source.strip().split(), target.strip().split()
 
24
  token_src, token_tgt = [tokenizer.tokenize(word) for word in sent_src], [
25
  tokenizer.tokenize(word) for word in sent_tgt]
 
26
  wid_src, wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_src], [
27
  tokenizer.convert_tokens_to_ids(x) for x in token_tgt]
28
- ids_src, ids_tgt = tokenizer.prepare_for_model(list(itertools.chain(*wid_src)), return_tensors='pt', model_max_length=tokenizer.model_max_length, truncation=True)[
29
- 'input_ids'], tokenizer.prepare_for_model(list(itertools.chain(*wid_tgt)), return_tensors='pt', truncation=True, model_max_length=tokenizer.model_max_length)['input_ids']
30
  sub2word_map_src = []
31
 
32
  for i, word_list in enumerate(token_src):
@@ -69,12 +85,12 @@ def get_alignment_mapping(source="", target="", model_path="musfiqdehan/bn-en-wo
69
 
70
 
71
 
72
- def get_word_mapping(source="", target="", model_path="musfiqdehan/bn-en-word-aligner"):
73
  """
74
  Get Word Aligned Mapping Words
75
  """
76
  sent_src, sent_tgt, align_words = get_alignment_mapping(
77
- source=source, target=target, model_path=model_path)
78
 
79
  result = []
80
 
@@ -85,16 +101,19 @@ def get_word_mapping(source="", target="", model_path="musfiqdehan/bn-en-word-al
85
 
86
 
87
 
88
- def get_word_index_mapping(source="", target="", model_path="musfiqdehan/bn-en-word-aligner"):
89
  """
90
  Get Word Aligned Mapping Index
91
  """
92
  sent_src, sent_tgt, align_words = get_alignment_mapping(
93
- source=source, target=target, model_path=model_path)
94
 
95
  result = []
96
 
97
  for i, j in sorted(align_words):
98
  result.append(f'bn:({i}) -> en:({j})')
99
 
100
- return result
 
 
 
 
12
  logging.set_verbosity_error()
13
 
14
 
15
+ def select_model(model_name):
16
+ """
17
+ Select Model
18
+ """
19
+ if model_name == "Google-mBERT (Base-Multilingual)":
20
+ model_name="bert-base-multilingual-cased"
21
+ elif model_name == "Neulab-AwesomeAlign (Bn-En-0.5M)":
22
+ model_name="musfiqdehan/bn-en-word-aligner"
23
+
24
+ return model_name
25
+
26
+
27
+ def get_alignment_mapping(source="", target="", model_name=""):
28
  """
29
  Get Aligned Words
30
  """
31
+ model_name = select_model(model_name)
32
+
33
+ model = transformers.BertModel.from_pretrained(model_name)
34
+ tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
35
 
36
  # pre-processing
37
  sent_src, sent_tgt = source.strip().split(), target.strip().split()
38
+
39
  token_src, token_tgt = [tokenizer.tokenize(word) for word in sent_src], [
40
  tokenizer.tokenize(word) for word in sent_tgt]
41
+
42
  wid_src, wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_src], [
43
  tokenizer.convert_tokens_to_ids(x) for x in token_tgt]
44
+
45
+ ids_src, ids_tgt = tokenizer.prepare_for_model(list(itertools.chain(*wid_src)), return_tensors='pt', model_max_length=tokenizer.model_max_length, truncation=True)['input_ids'], tokenizer.prepare_for_model(list(itertools.chain(*wid_tgt)), return_tensors='pt', truncation=True, model_max_length=tokenizer.model_max_length)['input_ids']
46
  sub2word_map_src = []
47
 
48
  for i, word_list in enumerate(token_src):
 
85
 
86
 
87
 
88
+ def get_word_mapping(source="", target="", model_name=""):
89
  """
90
  Get Word Aligned Mapping Words
91
  """
92
  sent_src, sent_tgt, align_words = get_alignment_mapping(
93
+ source=source, target=target, model_name=model_name)
94
 
95
  result = []
96
 
 
101
 
102
 
103
 
104
+ def get_word_index_mapping(source="", target="", model_name=""):
105
  """
106
  Get Word Aligned Mapping Index
107
  """
108
  sent_src, sent_tgt, align_words = get_alignment_mapping(
109
+ source=source, target=target, model_name=model_name)
110
 
111
  result = []
112
 
113
  for i, j in sorted(align_words):
114
  result.append(f'bn:({i}) -> en:({j})')
115
 
116
+ return result
117
+
118
+
119
+