Hussain Shaikh commited on
Commit
7edceed
·
1 Parent(s): 6325f49

final commit added required files

Browse files
.gitignore ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ignore libs folder we use
2
+ indic_nlp_library
3
+ indic_nlp_resources
4
+ subword-nmt
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
103
+ __pypackages__/
104
+
105
+ # Celery stuff
106
+ celerybeat-schedule
107
+ celerybeat.pid
108
+
109
+ # SageMath parsed files
110
+ *.sage.py
111
+
112
+ # Environments
113
+ .env
114
+ .venv
115
+ env/
116
+ venv/
117
+ ENV/
118
+ env.bak/
119
+ venv.bak/
120
+
121
+ # Spyder project settings
122
+ .spyderproject
123
+ .spyproject
124
+
125
+ # Rope project settings
126
+ .ropeproject
127
+
128
+ # mkdocs documentation
129
+ /site
130
+
131
+ # mypy
132
+ .mypy_cache/
133
+ .dmypy.json
134
+ dmypy.json
135
+
136
+ # Pyre type checker
137
+ .pyre/
138
+
139
+ # pytype static type analyzer
140
+ .pytype/
141
+
142
+ # Cython debug symbols
143
+ cython_debug/
api/api.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import re
4
+ from math import floor, ceil
5
+ from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
6
+ # from nltk.tokenize import sent_tokenize
7
+ from flask import Flask, request, jsonify
8
+ from flask_cors import CORS, cross_origin
9
+ import webvtt
10
+ from io import StringIO
11
+ from mosestokenizer import MosesSentenceSplitter
12
+
13
+ from indicTrans.inference.engine import Model
14
+ from punctuate import RestorePuncts
15
+ from indicnlp.tokenize.sentence_tokenize import sentence_split
16
+
17
+ app = Flask(__name__)
18
+ cors = CORS(app)
19
+ app.config['CORS_HEADERS'] = 'Content-Type'
20
+
21
+ indic2en_model = Model(expdir='models/v3/indic-en')
22
+ en2indic_model = Model(expdir='models/v3/en-indic')
23
+ m2m_model = Model(expdir='models/m2m')
24
+
25
+ rpunct = RestorePuncts()
26
+
27
+ indic_language_dict = {
28
+ 'Assamese': 'as',
29
+ 'Hindi' : 'hi',
30
+ 'Marathi' : 'mr',
31
+ 'Tamil' : 'ta',
32
+ 'Bengali' : 'bn',
33
+ 'Kannada' : 'kn',
34
+ 'Oriya' : 'or',
35
+ 'Telugu' : 'te',
36
+ 'Gujarati' : 'gu',
37
+ 'Malayalam' : 'ml',
38
+ 'Punjabi' : 'pa',
39
+ }
40
+
41
+ splitter = MosesSentenceSplitter('en')
42
+
43
+ def get_inference_params():
44
+ source_language = request.form['source_language']
45
+ target_language = request.form['target_language']
46
+
47
+ if source_language in indic_language_dict and target_language == 'English':
48
+ model = indic2en_model
49
+ source_lang = indic_language_dict[source_language]
50
+ target_lang = 'en'
51
+ elif source_language == 'English' and target_language in indic_language_dict:
52
+ model = en2indic_model
53
+ source_lang = 'en'
54
+ target_lang = indic_language_dict[target_language]
55
+ elif source_language in indic_language_dict and target_language in indic_language_dict:
56
+ model = m2m_model
57
+ source_lang = indic_language_dict[source_language]
58
+ target_lang = indic_language_dict[target_language]
59
+
60
+ return model, source_lang, target_lang
61
+
62
+ @app.route('/', methods=['GET'])
63
+ def main():
64
+ return "IndicTrans API"
65
+
66
+ @app.route('/supported_languages', methods=['GET'])
67
+ @cross_origin()
68
+ def supported_languages():
69
+ return jsonify(indic_language_dict)
70
+
71
+ @app.route("/translate", methods=['POST'])
72
+ @cross_origin()
73
+ def infer_indic_en():
74
+ model, source_lang, target_lang = get_inference_params()
75
+ source_text = request.form['text']
76
+
77
+ start_time = time.time()
78
+ target_text = model.translate_paragraph(source_text, source_lang, target_lang)
79
+ end_time = time.time()
80
+ return {'text':target_text, 'duration':round(end_time-start_time, 2)}
81
+
82
+ @app.route("/translate_vtt", methods=['POST'])
83
+ @cross_origin()
84
+ def infer_vtt_indic_en():
85
+ start_time = time.time()
86
+ model, source_lang, target_lang = get_inference_params()
87
+ source_text = request.form['text']
88
+ # vad_segments = request.form['vad_nochunk'] # Assuming it is an array of start & end timestamps
89
+
90
+ vad = webvtt.read_buffer(StringIO(source_text))
91
+ source_sentences = [v.text.replace('\r', '').replace('\n', ' ') for v in vad]
92
+
93
+ ## SUMANTH LOGIC HERE ##
94
+
95
+ # for each vad timestamp, do:
96
+ large_sentence = ' '.join(source_sentences) # only sentences in that time range
97
+ large_sentence = large_sentence.lower()
98
+ # split_sents = sentence_split(large_sentence, 'en')
99
+ # print(split_sents)
100
+
101
+ large_sentence = re.sub(r'[^\w\s]', '', large_sentence)
102
+ punctuated = rpunct.punctuate(large_sentence, batch_size=32)
103
+ end_time = time.time()
104
+ print("Time Taken for punctuation: {} s".format(end_time - start_time))
105
+ start_time = time.time()
106
+ split_sents = splitter([punctuated]) ### Please uncomment
107
+
108
+
109
+ # print(split_sents)
110
+ # output_sentence_punctuated = model.translate_paragraph(punctuated, source_lang, target_lang)
111
+ output_sents = model.batch_translate(split_sents, source_lang, target_lang)
112
+ # print(output_sents)
113
+ # output_sents = split_sents
114
+ # print(output_sents)
115
+ # align this to those range of source_sentences in `captions`
116
+
117
+ map_ = {split_sents[i] : output_sents[i] for i in range(len(split_sents))}
118
+ # print(map_)
119
+ punct_para = ' '.join(list(map_.keys()))
120
+ nmt_para = ' '.join(list(map_.values()))
121
+ nmt_words = nmt_para.split(' ')
122
+
123
+ len_punct = len(punct_para.split(' '))
124
+ len_nmt = len(nmt_para.split(' '))
125
+
126
+ start = 0
127
+ for i in range(len(vad)):
128
+ if vad[i].text == '':
129
+ continue
130
+
131
+ len_caption = len(vad[i].text.split(' '))
132
+ frac = (len_caption / len_punct)
133
+ # frac = round(frac, 2)
134
+
135
+ req_nmt_size = floor(frac * len_nmt)
136
+ # print(frac, req_nmt_size)
137
+
138
+ vad[i].text = ' '.join(nmt_words[start:start+req_nmt_size])
139
+ # print(vad[i].text)
140
+ # print(start, req_nmt_size)
141
+ start += req_nmt_size
142
+
143
+ end_time = time.time()
144
+
145
+ print("Time Taken for translation: {} s".format(end_time - start_time))
146
+
147
+ # vad.save('aligned.vtt')
148
+
149
+ return {
150
+ 'text': vad.content,
151
+ # 'duration':round(end_time-start_time, 2)
152
+ }
api/punctuate.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # 💾⚙️🔮
3
+
4
+ # taken from https://github.com/Felflare/rpunct/blob/master/rpunct/punctuate.py
5
+ # modified to support batching during gpu inference
6
+
7
+
8
+ __author__ = "Daulet N."
9
+ __email__ = "daulet.nurmanbetov@gmail.com"
10
+
11
+ import time
12
+ import logging
13
+ import webvtt
14
+ import torch
15
+ from io import StringIO
16
+ from nltk.tokenize import sent_tokenize
17
+ #from langdetect import detect
18
+ from simpletransformers.ner import NERModel
19
+
20
+
21
+ class RestorePuncts:
22
+ def __init__(self, wrds_per_pred=250):
23
+ self.wrds_per_pred = wrds_per_pred
24
+ self.overlap_wrds = 30
25
+ self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
26
+ self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels,
27
+ args={"silent": True, "max_seq_length": 512})
28
+ # use_cuda isnt working and this hack seems to load the model correctly to the gpu
29
+ self.model.device = torch.device("cuda:1")
30
+ # dummy punctuate to load the model onto gpu
31
+ self.punctuate("hello how are you")
32
+
33
+ def punctuate(self, text: str, batch_size:int=32, lang:str=''):
34
+ """
35
+ Performs punctuation restoration on arbitrarily large text.
36
+ Detects if input is not English, if non-English was detected terminates predictions.
37
+ Overrride by supplying `lang='en'`
38
+
39
+ Args:
40
+ - text (str): Text to punctuate, can be few words to as large as you want.
41
+ - lang (str): Explicit language of input text.
42
+ """
43
+ #if not lang and len(text) > 10:
44
+ # lang = detect(text)
45
+ #if lang != 'en':
46
+ # raise Exception(F"""Non English text detected. Restore Punctuation works only for English.
47
+ # If you are certain the input is English, pass argument lang='en' to this function.
48
+ # Punctuate received: {text}""")
49
+
50
+ def chunks(L, n):
51
+ return [L[x : x + n] for x in range(0, len(L), n)]
52
+
53
+
54
+
55
+ # plit up large text into bert digestable chunks
56
+ splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
57
+
58
+ texts = [i["text"] for i in splits]
59
+ batches = chunks(texts, batch_size)
60
+ preds_lst = []
61
+
62
+
63
+ for batch in batches:
64
+ batch_preds, _ = self.model.predict(batch)
65
+ preds_lst.extend(batch_preds)
66
+
67
+
68
+ # predict slices
69
+ # full_preds_lst contains tuple of labels and logits
70
+ #full_preds_lst = [self.predict(i['text']) for i in splits]
71
+ # extract predictions, and discard logits
72
+ #preds_lst = [i[0][0] for i in full_preds_lst]
73
+ # join text slices
74
+ combined_preds = self.combine_results(text, preds_lst)
75
+ # create punctuated prediction
76
+ punct_text = self.punctuate_texts(combined_preds)
77
+ return punct_text
78
+
79
+ def predict(self, input_slice):
80
+ """
81
+ Passes the unpunctuated text to the model for punctuation.
82
+ """
83
+ predictions, raw_outputs = self.model.predict([input_slice])
84
+ return predictions, raw_outputs
85
+
86
+ @staticmethod
87
+ def split_on_toks(text, length, overlap):
88
+ """
89
+ Splits text into predefined slices of overlapping text with indexes (offsets)
90
+ that tie-back to original text.
91
+ This is done to bypass 512 token limit on transformer models by sequentially
92
+ feeding chunks of < 512 toks.
93
+ Example output:
94
+ [{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]
95
+ """
96
+ wrds = text.replace('\n', ' ').split(" ")
97
+ resp = []
98
+ lst_chunk_idx = 0
99
+ i = 0
100
+
101
+ while True:
102
+ # words in the chunk and the overlapping portion
103
+ wrds_len = wrds[(length * i):(length * (i + 1))]
104
+ wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
105
+ wrds_split = wrds_len + wrds_ovlp
106
+
107
+ # Break loop if no more words
108
+ if not wrds_split:
109
+ break
110
+
111
+ wrds_str = " ".join(wrds_split)
112
+ nxt_chunk_start_idx = len(" ".join(wrds_len))
113
+ lst_char_idx = len(" ".join(wrds_split))
114
+
115
+ resp_obj = {
116
+ "text": wrds_str,
117
+ "start_idx": lst_chunk_idx,
118
+ "end_idx": lst_char_idx + lst_chunk_idx,
119
+ }
120
+
121
+ resp.append(resp_obj)
122
+ lst_chunk_idx += nxt_chunk_start_idx + 1
123
+ i += 1
124
+ logging.info(f"Sliced transcript into {len(resp)} slices.")
125
+ return resp
126
+
127
+ @staticmethod
128
+ def combine_results(full_text: str, text_slices):
129
+ """
130
+ Given a full text and predictions of each slice combines predictions into a single text again.
131
+ Performs validataion wether text was combined correctly
132
+ """
133
+ split_full_text = full_text.replace('\n', ' ').split(" ")
134
+ split_full_text = [i for i in split_full_text if i]
135
+ split_full_text_len = len(split_full_text)
136
+ output_text = []
137
+ index = 0
138
+
139
+ if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
140
+ text_slices = text_slices[:-1]
141
+
142
+ for _slice in text_slices:
143
+ slice_wrds = len(_slice)
144
+ for ix, wrd in enumerate(_slice):
145
+ # print(index, "|", str(list(wrd.keys())[0]), "|", split_full_text[index])
146
+ if index == split_full_text_len:
147
+ break
148
+
149
+ if split_full_text[index] == str(list(wrd.keys())[0]) and \
150
+ ix <= slice_wrds - 3 and text_slices[-1] != _slice:
151
+ index += 1
152
+ pred_item_tuple = list(wrd.items())[0]
153
+ output_text.append(pred_item_tuple)
154
+ elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
155
+ index += 1
156
+ pred_item_tuple = list(wrd.items())[0]
157
+ output_text.append(pred_item_tuple)
158
+ assert [i[0] for i in output_text] == split_full_text
159
+ return output_text
160
+
161
+ @staticmethod
162
+ def punctuate_texts(full_pred: list):
163
+ """
164
+ Given a list of Predictions from the model, applies the predictions to text,
165
+ thus punctuating it.
166
+ """
167
+ punct_resp = ""
168
+ for i in full_pred:
169
+ word, label = i
170
+ if label[-1] == "U":
171
+ punct_wrd = word.capitalize()
172
+ else:
173
+ punct_wrd = word
174
+
175
+ if label[0] != "O":
176
+ punct_wrd += label[0]
177
+
178
+ punct_resp += punct_wrd + " "
179
+ punct_resp = punct_resp.strip()
180
+ # Append trailing period if doesnt exist.
181
+ if punct_resp[-1].isalnum():
182
+ punct_resp += "."
183
+ return punct_resp
184
+
185
+
186
+ if __name__ == "__main__":
187
+
188
+ start = time.time()
189
+ punct_model = RestorePuncts()
190
+
191
+ load_model = time.time()
192
+ print(f'Time to load model: {load_model - start}')
193
+ # read test file
194
+ # with open('en_lower.txt', 'r') as fp:
195
+ # # test_sample = fp.read()
196
+ # lines = fp.readlines()
197
+
198
+ with open('sample.vtt', 'r') as fp:
199
+ source_text = fp.read()
200
+
201
+ # captions = webvtt.read_buffer(StringIO(source_text))
202
+ captions = webvtt.read('sample.vtt')
203
+ source_sentences = [caption.text.replace('\r', '').replace('\n', ' ') for caption in captions]
204
+
205
+ # print(source_sentences)
206
+
207
+ sent = ' '.join(source_sentences)
208
+ punctuated = punct_model.punctuate(sent)
209
+
210
+ tokenised = sent_tokenize(punctuated)
211
+ # print(tokenised)
212
+
213
+ for i in range(len(tokenised)):
214
+ captions[i].text = tokenised[i]
215
+ # return captions.content
216
+ captions.save('my_captions.vtt')
217
+
218
+ end = time.time()
219
+ print(f'Time for run: {end - load_model}')
220
+ print(f'Total time: {end - start}')
app.py CHANGED
@@ -1,7 +1,30 @@
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hell" + name + "!!"
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
 
4
+ download="wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1IpcnaQ2ScX_zodt2aLlXa_5Kkntl0nue' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\\n/p')&id=1IpcnaQ2ScX_zodt2aLlXa_5Kkntl0nue\" -O en-indic.zip && rm -rf /tmp/cookies.txt"
5
+ os.system(download)
6
+ os.system('unzip /home/user/app/en-indic.zip')
7
 
8
+ from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
9
+ import gradio as gr
10
+ from inference.engine import Model
11
+ indic2en_model = Model(expdir='/home/user/app/en-indic')
12
+
13
+ INDIC = {"Assamese": "as", "Bengali": "bn", "Gujarati": "gu", "Hindi": "hi","Kannada": "kn","Malayalam": "ml", "Marathi": "mr", "Odia": "or","Punjabi": "pa","Tamil": "ta", "Telugu" : "te"}
14
+
15
+
16
+ def translate(text, lang):
17
+ return indic2en_model.translate_paragraph(text, 'en', INDIC[lang])
18
+
19
+
20
+ languages = list(INDIC.keys())
21
+ drop_down = gr.inputs.Dropdown(languages, type="value", default="Hindi", label="Select Target Language")
22
+ text = gr.inputs.Textbox(lines=5, placeholder="Enter Text to translate", default="", label="Enter Text in English")
23
+ text_ouptut = gr.outputs.Textbox(type="auto", label="Translated text in Target Language")
24
+
25
+ # example=[['I want to translate this sentence in Hindi','Hindi'],
26
+ # ['I am feeling very good today.', 'Bengali']]
27
+
28
+ supported_lang = ', '.join(languages)
29
+ iface = gr.Interface(fn=translate, inputs=[text,drop_down] , outputs=text_ouptut, title='IndicTrans NMT System', description = 'Currently the model supports ' + supported_lang, article = 'Original repository can be found [here](https://github.com/AI4Bharat/indicTrans)' , examples=None)
30
+ iface.launch(enable_queue=True)
inference/__init__.py ADDED
File without changes
inference/custom_interactive.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python wrapper for fairseq-interactive command line tool
2
+
3
+ #!/usr/bin/env python3 -u
4
+ # Copyright (c) Facebook, Inc. and its affiliates.
5
+ #
6
+ # This source code is licensed under the MIT license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+ """
9
+ Translate raw text with a trained model. Batches data on-the-fly.
10
+ """
11
+
12
+ import ast
13
+ from collections import namedtuple
14
+
15
+ import torch
16
+ from fairseq import checkpoint_utils, options, tasks, utils
17
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
18
+ from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
19
+ from fairseq_cli.generate import get_symbols_to_strip_from_output
20
+
21
+ import codecs
22
+
23
+
24
+ Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
25
+ Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")
26
+
27
+
28
+ def make_batches(
29
+ lines, cfg, task, max_positions, encode_fn, constrainted_decoding=False
30
+ ):
31
+ def encode_fn_target(x):
32
+ return encode_fn(x)
33
+
34
+ if constrainted_decoding:
35
+ # Strip (tab-delimited) contraints, if present, from input lines,
36
+ # store them in batch_constraints
37
+ batch_constraints = [list() for _ in lines]
38
+ for i, line in enumerate(lines):
39
+ if "\t" in line:
40
+ lines[i], *batch_constraints[i] = line.split("\t")
41
+
42
+ # Convert each List[str] to List[Tensor]
43
+ for i, constraint_list in enumerate(batch_constraints):
44
+ batch_constraints[i] = [
45
+ task.target_dictionary.encode_line(
46
+ encode_fn_target(constraint),
47
+ append_eos=False,
48
+ add_if_not_exist=False,
49
+ )
50
+ for constraint in constraint_list
51
+ ]
52
+
53
+ if constrainted_decoding:
54
+ constraints_tensor = pack_constraints(batch_constraints)
55
+ else:
56
+ constraints_tensor = None
57
+
58
+ tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)
59
+
60
+ itr = task.get_batch_iterator(
61
+ dataset=task.build_dataset_for_inference(
62
+ tokens, lengths, constraints=constraints_tensor
63
+ ),
64
+ max_tokens=cfg.dataset.max_tokens,
65
+ max_sentences=cfg.dataset.batch_size,
66
+ max_positions=max_positions,
67
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
68
+ ).next_epoch_itr(shuffle=False)
69
+ for batch in itr:
70
+ ids = batch["id"]
71
+ src_tokens = batch["net_input"]["src_tokens"]
72
+ src_lengths = batch["net_input"]["src_lengths"]
73
+ constraints = batch.get("constraints", None)
74
+
75
+ yield Batch(
76
+ ids=ids,
77
+ src_tokens=src_tokens,
78
+ src_lengths=src_lengths,
79
+ constraints=constraints,
80
+ )
81
+
82
+
83
+ class Translator:
84
+ def __init__(
85
+ self, data_dir, checkpoint_path, batch_size=25, constrained_decoding=False
86
+ ):
87
+
88
+ self.constrained_decoding = constrained_decoding
89
+ self.parser = options.get_generation_parser(interactive=True)
90
+ # buffer_size is currently not used but we just initialize it to batch
91
+ # size + 1 to avoid any assertion errors.
92
+ if self.constrained_decoding:
93
+ self.parser.set_defaults(
94
+ path=checkpoint_path,
95
+ remove_bpe="subword_nmt",
96
+ num_workers=-1,
97
+ constraints="ordered",
98
+ batch_size=batch_size,
99
+ buffer_size=batch_size + 1,
100
+ )
101
+ else:
102
+ self.parser.set_defaults(
103
+ path=checkpoint_path,
104
+ remove_bpe="subword_nmt",
105
+ num_workers=-1,
106
+ batch_size=batch_size,
107
+ buffer_size=batch_size + 1,
108
+ )
109
+ args = options.parse_args_and_arch(self.parser, input_args=[data_dir])
110
+ # we are explictly setting src_lang and tgt_lang here
111
+ # generally the data_dir we pass contains {split}-{src_lang}-{tgt_lang}.*.idx files from
112
+ # which fairseq infers the src and tgt langs(if these are not passed). In deployment we dont
113
+ # use any idx files and only store the SRC and TGT dictionaries.
114
+ args.source_lang = "SRC"
115
+ args.target_lang = "TGT"
116
+ # since we are truncating sentences to max_seq_len in engine, we can set it to False here
117
+ args.skip_invalid_size_inputs_valid_test = False
118
+
119
+ # we have custom architechtures in this folder and we will let fairseq
120
+ # import this
121
+ args.user_dir = "model_configs"
122
+ self.cfg = convert_namespace_to_omegaconf(args)
123
+
124
+ utils.import_user_module(self.cfg.common)
125
+
126
+ if self.cfg.interactive.buffer_size < 1:
127
+ self.cfg.interactive.buffer_size = 1
128
+ if self.cfg.dataset.max_tokens is None and self.cfg.dataset.batch_size is None:
129
+ self.cfg.dataset.batch_size = 1
130
+
131
+ assert (
132
+ not self.cfg.generation.sampling
133
+ or self.cfg.generation.nbest == self.cfg.generation.beam
134
+ ), "--sampling requires --nbest to be equal to --beam"
135
+ assert (
136
+ not self.cfg.dataset.batch_size
137
+ or self.cfg.dataset.batch_size <= self.cfg.interactive.buffer_size
138
+ ), "--batch-size cannot be larger than --buffer-size"
139
+
140
+ # Fix seed for stochastic decoding
141
+ # if self.cfg.common.seed is not None and not self.cfg.generation.no_seed_provided:
142
+ # np.random.seed(self.cfg.common.seed)
143
+ # utils.set_torch_seed(self.cfg.common.seed)
144
+
145
+ # if not self.constrained_decoding:
146
+ # self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu
147
+ # else:
148
+ # self.use_cuda = False
149
+
150
+ self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu
151
+
152
+ # Setup task, e.g., translation
153
+ self.task = tasks.setup_task(self.cfg.task)
154
+
155
+ # Load ensemble
156
+ overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
157
+ self.models, self._model_args = checkpoint_utils.load_model_ensemble(
158
+ utils.split_paths(self.cfg.common_eval.path),
159
+ arg_overrides=overrides,
160
+ task=self.task,
161
+ suffix=self.cfg.checkpoint.checkpoint_suffix,
162
+ strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
163
+ num_shards=self.cfg.checkpoint.checkpoint_shard_count,
164
+ )
165
+
166
+ # Set dictionaries
167
+ self.src_dict = self.task.source_dictionary
168
+ self.tgt_dict = self.task.target_dictionary
169
+
170
+ # Optimize ensemble for generation
171
+ for model in self.models:
172
+ if model is None:
173
+ continue
174
+ if self.cfg.common.fp16:
175
+ model.half()
176
+ if (
177
+ self.use_cuda
178
+ and not self.cfg.distributed_training.pipeline_model_parallel
179
+ ):
180
+ model.cuda()
181
+ model.prepare_for_inference_(self.cfg)
182
+
183
+ # Initialize generator
184
+ self.generator = self.task.build_generator(self.models, self.cfg.generation)
185
+
186
+ # Handle tokenization and BPE
187
+ self.tokenizer = self.task.build_tokenizer(self.cfg.tokenizer)
188
+ self.bpe = self.task.build_bpe(self.cfg.bpe)
189
+
190
+ # Load alignment dictionary for unknown word replacement
191
+ # (None if no unknown word replacement, empty if no path to align dictionary)
192
+ self.align_dict = utils.load_align_dict(self.cfg.generation.replace_unk)
193
+
194
+ self.max_positions = utils.resolve_max_positions(
195
+ self.task.max_positions(), *[model.max_positions() for model in self.models]
196
+ )
197
+
198
+ def encode_fn(self, x):
199
+ if self.tokenizer is not None:
200
+ x = self.tokenizer.encode(x)
201
+ if self.bpe is not None:
202
+ x = self.bpe.encode(x)
203
+ return x
204
+
205
+ def decode_fn(self, x):
206
+ if self.bpe is not None:
207
+ x = self.bpe.decode(x)
208
+ if self.tokenizer is not None:
209
+ x = self.tokenizer.decode(x)
210
+ return x
211
+
212
+ def translate(self, inputs, constraints=None):
213
+ if self.constrained_decoding and constraints is None:
214
+ raise ValueError("Constraints cant be None in constrained decoding mode")
215
+ if not self.constrained_decoding and constraints is not None:
216
+ raise ValueError("Cannot pass constraints during normal translation")
217
+ if constraints:
218
+ constrained_decoding = True
219
+ modified_inputs = []
220
+ for _input, constraint in zip(inputs, constraints):
221
+ modified_inputs.append(_input + f"\t{constraint}")
222
+ inputs = modified_inputs
223
+ else:
224
+ constrained_decoding = False
225
+
226
+ start_id = 0
227
+ results = []
228
+ final_translations = []
229
+ for batch in make_batches(
230
+ inputs,
231
+ self.cfg,
232
+ self.task,
233
+ self.max_positions,
234
+ self.encode_fn,
235
+ constrained_decoding,
236
+ ):
237
+ bsz = batch.src_tokens.size(0)
238
+ src_tokens = batch.src_tokens
239
+ src_lengths = batch.src_lengths
240
+ constraints = batch.constraints
241
+ if self.use_cuda:
242
+ src_tokens = src_tokens.cuda()
243
+ src_lengths = src_lengths.cuda()
244
+ if constraints is not None:
245
+ constraints = constraints.cuda()
246
+
247
+ sample = {
248
+ "net_input": {
249
+ "src_tokens": src_tokens,
250
+ "src_lengths": src_lengths,
251
+ },
252
+ }
253
+
254
+ translations = self.task.inference_step(
255
+ self.generator, self.models, sample, constraints=constraints
256
+ )
257
+
258
+ list_constraints = [[] for _ in range(bsz)]
259
+ if constrained_decoding:
260
+ list_constraints = [unpack_constraints(c) for c in constraints]
261
+ for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
262
+ src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
263
+ constraints = list_constraints[i]
264
+ results.append(
265
+ (
266
+ start_id + id,
267
+ src_tokens_i,
268
+ hypos,
269
+ {
270
+ "constraints": constraints,
271
+ },
272
+ )
273
+ )
274
+
275
+ # sort output to match input order
276
+ for id_, src_tokens, hypos, _ in sorted(results, key=lambda x: x[0]):
277
+ src_str = ""
278
+ if self.src_dict is not None:
279
+ src_str = self.src_dict.string(
280
+ src_tokens, self.cfg.common_eval.post_process
281
+ )
282
+
283
+ # Process top predictions
284
+ for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
285
+ hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
286
+ hypo_tokens=hypo["tokens"].int().cpu(),
287
+ src_str=src_str,
288
+ alignment=hypo["alignment"],
289
+ align_dict=self.align_dict,
290
+ tgt_dict=self.tgt_dict,
291
+ remove_bpe="subword_nmt",
292
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(
293
+ self.generator
294
+ ),
295
+ )
296
+ detok_hypo_str = self.decode_fn(hypo_str)
297
+ final_translations.append(detok_hypo_str)
298
+ return final_translations
inference/engine.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import truncate
2
+ from sacremoses import MosesPunctNormalizer
3
+ from sacremoses import MosesTokenizer
4
+ from sacremoses import MosesDetokenizer
5
+ from subword_nmt.apply_bpe import BPE, read_vocabulary
6
+ import codecs
7
+ from tqdm import tqdm
8
+ from indicnlp.tokenize import indic_tokenize
9
+ from indicnlp.tokenize import indic_detokenize
10
+ from indicnlp.normalize import indic_normalize
11
+ from indicnlp.transliterate import unicode_transliterate
12
+ from mosestokenizer import MosesSentenceSplitter
13
+ from indicnlp.tokenize import sentence_tokenize
14
+
15
+ from inference.custom_interactive import Translator
16
+
17
+
18
+ INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
19
+
20
+
21
+ def split_sentences(paragraph, language):
22
+ if language == "en":
23
+ with MosesSentenceSplitter(language) as splitter:
24
+ return splitter([paragraph])
25
+ elif language in INDIC:
26
+ return sentence_tokenize.sentence_split(paragraph, lang=language)
27
+
28
+
29
+ def add_token(sent, tag_infos):
30
+ """add special tokens specified by tag_infos to each element in list
31
+
32
+ tag_infos: list of tuples (tag_type,tag)
33
+
34
+ each tag_info results in a token of the form: __{tag_type}__{tag}__
35
+
36
+ """
37
+
38
+ tokens = []
39
+ for tag_type, tag in tag_infos:
40
+ token = "__" + tag_type + "__" + tag + "__"
41
+ tokens.append(token)
42
+
43
+ return " ".join(tokens) + " " + sent
44
+
45
+
46
+ def apply_lang_tags(sents, src_lang, tgt_lang):
47
+ tagged_sents = []
48
+ for sent in sents:
49
+ tagged_sent = add_token(sent.strip(), [("src", src_lang), ("tgt", tgt_lang)])
50
+ tagged_sents.append(tagged_sent)
51
+ return tagged_sents
52
+
53
+
54
+ def truncate_long_sentences(sents):
55
+
56
+ MAX_SEQ_LEN = 200
57
+ new_sents = []
58
+
59
+ for sent in sents:
60
+ words = sent.split()
61
+ num_words = len(words)
62
+ if num_words > MAX_SEQ_LEN:
63
+ print_str = " ".join(words[:5]) + " .... " + " ".join(words[-5:])
64
+ sent = " ".join(words[:MAX_SEQ_LEN])
65
+ print(
66
+ f"WARNING: Sentence {print_str} truncated to 200 tokens as it exceeds maximum length limit"
67
+ )
68
+
69
+ new_sents.append(sent)
70
+ return new_sents
71
+
72
+
73
+ class Model:
74
+ def __init__(self, expdir):
75
+ self.expdir = expdir
76
+ self.en_tok = MosesTokenizer(lang="en")
77
+ self.en_normalizer = MosesPunctNormalizer()
78
+ self.en_detok = MosesDetokenizer(lang="en")
79
+ self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()
80
+ print("Initializing vocab and bpe")
81
+ self.vocabulary = read_vocabulary(
82
+ codecs.open(f"{expdir}/vocab/vocab.SRC", encoding="utf-8"), 5
83
+ )
84
+ self.bpe = BPE(
85
+ codecs.open(f"{expdir}/vocab/bpe_codes.32k.SRC", encoding="utf-8"),
86
+ -1,
87
+ "@@",
88
+ self.vocabulary,
89
+ None,
90
+ )
91
+
92
+ print("Initializing model for translation")
93
+ # initialize the model
94
+ self.translator = Translator(
95
+ f"{expdir}/final_bin", f"{expdir}/model/checkpoint_best.pt", batch_size=100
96
+ )
97
+
98
+ # translate a batch of sentences from src_lang to tgt_lang
99
+ def batch_translate(self, batch, src_lang, tgt_lang):
100
+
101
+ assert isinstance(batch, list)
102
+ preprocessed_sents = self.preprocess(batch, lang=src_lang)
103
+ bpe_sents = self.apply_bpe(preprocessed_sents)
104
+ tagged_sents = apply_lang_tags(bpe_sents, src_lang, tgt_lang)
105
+ tagged_sents = truncate_long_sentences(tagged_sents)
106
+
107
+ translations = self.translator.translate(tagged_sents)
108
+ postprocessed_sents = self.postprocess(translations, tgt_lang)
109
+
110
+ return postprocessed_sents
111
+
112
+ # translate a paragraph from src_lang to tgt_lang
113
+ def translate_paragraph(self, paragraph, src_lang, tgt_lang):
114
+
115
+ assert isinstance(paragraph, str)
116
+ sents = split_sentences(paragraph, src_lang)
117
+
118
+ postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
119
+
120
+ translated_paragraph = " ".join(postprocessed_sents)
121
+
122
+ return translated_paragraph
123
+
124
+ def preprocess_sent(self, sent, normalizer, lang):
125
+ if lang == "en":
126
+ return " ".join(
127
+ self.en_tok.tokenize(
128
+ self.en_normalizer.normalize(sent.strip()), escape=False
129
+ )
130
+ )
131
+ else:
132
+ # line = indic_detokenize.trivial_detokenize(line.strip(), lang)
133
+ return unicode_transliterate.UnicodeIndicTransliterator.transliterate(
134
+ " ".join(
135
+ indic_tokenize.trivial_tokenize(
136
+ normalizer.normalize(sent.strip()), lang
137
+ )
138
+ ),
139
+ lang,
140
+ "hi",
141
+ ).replace(" ् ", "्")
142
+
143
+ def preprocess(self, sents, lang):
144
+ """
145
+ Normalize, tokenize and script convert(for Indic)
146
+ return number of sentences input file
147
+
148
+ """
149
+
150
+ if lang == "en":
151
+
152
+ # processed_sents = Parallel(n_jobs=-1, backend="multiprocessing")(
153
+ # delayed(preprocess_line)(line, None, lang) for line in tqdm(sents, total=num_lines)
154
+ # )
155
+ processed_sents = [
156
+ self.preprocess_sent(line, None, lang) for line in tqdm(sents)
157
+ ]
158
+
159
+ else:
160
+ normfactory = indic_normalize.IndicNormalizerFactory()
161
+ normalizer = normfactory.get_normalizer(lang)
162
+
163
+ # processed_sents = Parallel(n_jobs=-1, backend="multiprocessing")(
164
+ # delayed(preprocess_line)(line, normalizer, lang) for line in tqdm(infile, total=num_lines)
165
+ # )
166
+ processed_sents = [
167
+ self.preprocess_sent(line, normalizer, lang) for line in tqdm(sents)
168
+ ]
169
+
170
+ return processed_sents
171
+
172
+ def postprocess(self, sents, lang, common_lang="hi"):
173
+ """
174
+ parse fairseq interactive output, convert script back to native Indic script (in case of Indic languages) and detokenize.
175
+
176
+ infname: fairseq log file
177
+ outfname: output file of translation (sentences not translated contain the dummy string 'DUMMY_OUTPUT'
178
+ input_size: expected number of output sentences
179
+ lang: language
180
+ """
181
+ postprocessed_sents = []
182
+
183
+ if lang == "en":
184
+ for sent in sents:
185
+ # outfile.write(en_detok.detokenize(sent.split(" ")) + "\n")
186
+ postprocessed_sents.append(self.en_detok.detokenize(sent.split(" ")))
187
+ else:
188
+ for sent in sents:
189
+ outstr = indic_detokenize.trivial_detokenize(
190
+ self.xliterator.transliterate(sent, common_lang, lang), lang
191
+ )
192
+ # outfile.write(outstr + "\n")
193
+ postprocessed_sents.append(outstr)
194
+ return postprocessed_sents
195
+
196
+ def apply_bpe(self, sents):
197
+
198
+ return [self.bpe.process_line(sent) for sent in sents]
legacy/apply_bpe_test_valid_notag.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ expdir=$1 # EXPDIR
4
+ org_data_dir=$2
5
+ langs=$3
6
+
7
+ #`dirname $0`/env.sh
8
+ SUBWORD_NMT_DIR="subword-nmt"
9
+ echo "Apply to each language"
10
+
11
+ for dset in `echo test dev`
12
+ do
13
+ echo $dset
14
+
15
+ in_dset_dir="$org_data_dir/$dset"
16
+ out_dset_dir="$expdir/bpe/$dset"
17
+
18
+ for lang in $langs
19
+ do
20
+
21
+ echo Apply BPE for $dset "-" $lang
22
+
23
+ mkdir -p $out_dset_dir
24
+
25
+ python $SUBWORD_NMT_DIR/subword_nmt/apply_bpe.py \
26
+ -c $expdir/vocab/bpe_codes.32k.SRC_TGT \
27
+ --vocabulary $expdir/vocab/vocab.SRC \
28
+ --vocabulary-threshold 5 \
29
+ < $in_dset_dir/$dset.$lang \
30
+ > $out_dset_dir/$dset.$lang
31
+
32
+ done
33
+ done
legacy/apply_bpe_train_notag.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ expdir=$1 # EXPDIR
4
+
5
+ #`dirname $0`/env.sh
6
+ SUBWORD_NMT_DIR="subword-nmt"
7
+
8
+ data_dir="$expdir/data"
9
+ train_file=$data_dir/train
10
+ bpe_file=$expdir/bpe/train/train
11
+
12
+ mkdir -p $expdir/bpe/train
13
+
14
+ echo "Apply to SRC corpus"
15
+
16
+ python $SUBWORD_NMT_DIR/subword_nmt/apply_bpe.py \
17
+ -c $expdir/vocab/bpe_codes.32k.SRC_TGT \
18
+ --vocabulary $expdir/vocab/vocab.SRC \
19
+ --vocabulary-threshold 5 \
20
+ --num-workers "-1" \
21
+ < $train_file.SRC \
22
+ > $bpe_file.SRC
23
+
24
+ echo "Apply to TGT corpus"
25
+
26
+ python $SUBWORD_NMT_DIR/subword_nmt/apply_bpe.py \
27
+ -c $expdir/vocab/bpe_codes.32k.SRC_TGT \
28
+ --vocabulary $expdir/vocab/vocab.TGT \
29
+ --vocabulary-threshold 5 \
30
+ --num-workers "-1" \
31
+ < $train_file.TGT \
32
+ > $bpe_file.TGT
33
+
legacy/env.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ export SRC=''
3
+
4
+ ## Python env directory where fairseq is installed
5
+ export PYTHON_ENV=''
6
+
7
+ export SUBWORD_NMT_DIR=''
8
+ export INDIC_RESOURCES_PATH=''
9
+ export INDIC_NLP_HOME=''
10
+
11
+ export CUDA_HOME=''
12
+
13
+ export PATH=$CUDA_HOME/bin:$INDIC_NLP_HOME:$PATH
14
+ export LD_LIBRARY_PATH=$CUDA_HOME/lib64
15
+
16
+ # set environment variable to control GPUS visible to the application
17
+ #export CUDA_VISIBLE_DEVICES="'
legacy/indictrans_workflow.ipynb ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import random\n",
11
+ "from tqdm.notebook import tqdm\n",
12
+ "from sacremoses import MosesPunctNormalizer\n",
13
+ "from sacremoses import MosesTokenizer\n",
14
+ "from sacremoses import MosesDetokenizer\n",
15
+ "from collections import defaultdict\n",
16
+ "import sacrebleu"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "# The path to the local git repo for Indic NLP library\n",
26
+ "INDIC_NLP_LIB_HOME=\"\"\n",
27
+ "\n",
28
+ "# The path to the local git repo for Indic NLP Resources\n",
29
+ "INDIC_NLP_RESOURCES=\"\"\n",
30
+ "\n",
31
+ "import sys\n",
32
+ "sys.path.append(r'{}'.format(INDIC_NLP_LIB_HOME))\n",
33
+ "\n",
34
+ "from indicnlp import common\n",
35
+ "common.set_resources_path(INDIC_NLP_RESOURCES)\n",
36
+ "\n",
37
+ "from indicnlp import loader\n",
38
+ "loader.load()"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "import indicnlp\n",
48
+ "from indicnlp.tokenize import indic_tokenize\n",
49
+ "from indicnlp.tokenize import indic_detokenize\n",
50
+ "from indicnlp.normalize import indic_normalize\n",
51
+ "from indicnlp.transliterate import unicode_transliterate"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "LANGS=[\n",
61
+ " \"bn\",\n",
62
+ " \"gu\",\n",
63
+ " \"hi\",\n",
64
+ " \"kn\",\n",
65
+ " \"ml\",\n",
66
+ " \"mr\",\n",
67
+ " \"or\",\n",
68
+ " \"pa\",\n",
69
+ " \"ta\",\n",
70
+ " \"te\", \n",
71
+ "]"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "def preprocess(infname,outfname,lang):\n",
81
+ " \"\"\"\n",
82
+ " Preparing each corpus file: \n",
83
+ " - Normalization\n",
84
+ " - Tokenization \n",
85
+ " - Script coversion to Devanagari for Indic scripts\n",
86
+ " \"\"\"\n",
87
+ " \n",
88
+ " ### reading \n",
89
+ " with open(infname,'r',encoding='utf-8') as infile, \\\n",
90
+ " open(outfname,'w',encoding='utf-8') as outfile:\n",
91
+ " \n",
92
+ " if lang=='en':\n",
93
+ " en_tok=MosesTokenizer(lang='en')\n",
94
+ " en_normalizer = MosesPunctNormalizer()\n",
95
+ " for line in tqdm(infile): \n",
96
+ " outline=' '.join(\n",
97
+ " en_tok.tokenize( \n",
98
+ " en_normalizer.normalize(line.strip()), \n",
99
+ " escape=False ) )\n",
100
+ " outfile.write(outline+'\\n')\n",
101
+ " \n",
102
+ " else:\n",
103
+ " normfactory=indic_normalize.IndicNormalizerFactory()\n",
104
+ " normalizer=normfactory.get_normalizer(lang)\n",
105
+ " for line in tqdm(infile): \n",
106
+ " outline=unicode_transliterate.UnicodeIndicTransliterator.transliterate(\n",
107
+ " ' '.join(\n",
108
+ " indic_tokenize.trivial_tokenize(\n",
109
+ " normalizer.normalize(line.strip()), lang) ), lang, 'hi').replace(' ् ','्')\n",
110
+ "\n",
111
+ "\n",
112
+ " outfile.write(outline+'\\n')"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "def add_token(sent, tag_infos):\n",
122
+ " \"\"\" add special tokens specified by tag_infos to each element in list\n",
123
+ "\n",
124
+ " tag_infos: list of tuples (tag_type,tag)\n",
125
+ "\n",
126
+ " each tag_info results in a token of the form: __{tag_type}__{tag}__\n",
127
+ "\n",
128
+ " \"\"\"\n",
129
+ "\n",
130
+ " tokens=[]\n",
131
+ " for tag_type, tag in tag_infos:\n",
132
+ " token = '__' + tag_type + '__' + tag + '__'\n",
133
+ " tokens.append(token)\n",
134
+ "\n",
135
+ " return ' '.join(tokens) + ' ' + sent \n",
136
+ "\n",
137
+ "\n",
138
+ "def concat_data(data_dir, outdir, lang_pair_list, out_src_lang='SRC', out_trg_lang='TGT'):\n",
139
+ " \"\"\"\n",
140
+ " data_dir: input dir, contains directories for language pairs named l1-l2\n",
141
+ " \"\"\"\n",
142
+ " os.makedirs(outdir,exist_ok=True)\n",
143
+ "\n",
144
+ " out_src_fname='{}/train.{}'.format(outdir,out_src_lang)\n",
145
+ " out_trg_fname='{}/train.{}'.format(outdir,out_trg_lang)\n",
146
+ "# out_meta_fname='{}/metadata.txt'.format(outdir)\n",
147
+ "\n",
148
+ " print()\n",
149
+ " print(out_src_fname)\n",
150
+ " print(out_trg_fname)\n",
151
+ "# print(out_meta_fname)\n",
152
+ "\n",
153
+ " ### concatenate train data \n",
154
+ " if os.path.isfile(out_src_fname):\n",
155
+ " os.unlink(out_src_fname)\n",
156
+ " if os.path.isfile(out_trg_fname):\n",
157
+ " os.unlink(out_trg_fname)\n",
158
+ "# if os.path.isfile(out_meta_fname):\n",
159
+ "# os.unlink(out_meta_fname)\n",
160
+ "\n",
161
+ " for src_lang, trg_lang in tqdm(lang_pair_list):\n",
162
+ " print('src: {}, tgt:{}'.format(src_lang,trg_lang)) \n",
163
+ "\n",
164
+ " in_src_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,src_lang)\n",
165
+ " in_trg_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,trg_lang)\n",
166
+ "\n",
167
+ " print(in_src_fname)\n",
168
+ " os.system('cat {} >> {}'.format(in_src_fname,out_src_fname))\n",
169
+ "\n",
170
+ " print(in_trg_fname)\n",
171
+ " os.system('cat {} >> {}'.format(in_trg_fname,out_trg_fname)) \n",
172
+ " \n",
173
+ " \n",
174
+ "# with open('{}/lang_pairs.txt'.format(outdir),'w',encoding='utf-8') as lpfile: \n",
175
+ "# lpfile.write('\\n'.join( [ '-'.join(x) for x in lang_pair_list ] ))\n",
176
+ " \n",
177
+ " corpus_stats(data_dir, outdir, lang_pair_list)\n",
178
+ " \n",
179
+ "def corpus_stats(data_dir, outdir, lang_pair_list):\n",
180
+ " \"\"\"\n",
181
+ " data_dir: input dir, contains directories for language pairs named l1-l2\n",
182
+ " \"\"\"\n",
183
+ "\n",
184
+ " with open('{}/lang_pairs.txt'.format(outdir),'w',encoding='utf-8') as lpfile: \n",
185
+ "\n",
186
+ " for src_lang, trg_lang in tqdm(lang_pair_list):\n",
187
+ " print('src: {}, tgt:{}'.format(src_lang,trg_lang)) \n",
188
+ "\n",
189
+ " in_src_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,src_lang)\n",
190
+ " # in_trg_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,trg_lang)\n",
191
+ "\n",
192
+ " print(in_src_fname)\n",
193
+ " corpus_size=0\n",
194
+ " with open(in_src_fname,'r',encoding='utf-8') as infile:\n",
195
+ " corpus_size=sum(map(lambda x:1,infile))\n",
196
+ " \n",
197
+ " lpfile.write('{}\\t{}\\t{}\\n'.format(src_lang,trg_lang,corpus_size))\n",
198
+ " \n",
199
+ "def generate_lang_tag_iterator(infname):\n",
200
+ " with open(infname,'r',encoding='utf-8') as infile:\n",
201
+ " for line in infile:\n",
202
+ " src,tgt,count=line.strip().split('\\t')\n",
203
+ " count=int(count)\n",
204
+ " for _ in range(count):\n",
205
+ " yield (src,tgt) "
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": null,
211
+ "metadata": {},
212
+ "outputs": [],
213
+ "source": [
214
+ "#### directory containing all experiments \n",
215
+ "## one directory per experiment \n",
216
+ "EXPBASEDIR=''\n",
217
+ "\n",
218
+ "### directory containing data\n",
219
+ "## contains 3 directories: train test dev\n",
220
+ "## train directory structure: \n",
221
+ "## - There is one directory for each language pair\n",
222
+ "## - Directory naming convention lang1-lang2 (you need another directory/softlink for lang2-lang1)\n",
223
+ "## - Each directory contains 6 files: {train,test,dev}.{lang1,lang2}\n",
224
+ "## test & dev directory structure \n",
225
+ "## - test: contains files {test.l1,test.l2,test.l3} - assumes parallel test files like the wat2021 dataset\n",
226
+ "## - valid: contains files {dev.l1,dev.l2,dev.l3} - assumes parallel test files like the wat2021 dataset\n",
227
+ "## All files are tokenized\n",
228
+ "ORG_DATA_DIR='{d}/consolidated_unique_preprocessed'.format(d=BASEDIR)\n",
229
+ "\n"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "markdown",
234
+ "metadata": {},
235
+ "source": [
236
+ "# Exp2 (M2O)\n",
237
+ "\n",
238
+ "- All *-en "
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "markdown",
243
+ "metadata": {},
244
+ "source": [
245
+ "**Params**"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "metadata": {},
252
+ "outputs": [],
253
+ "source": [
254
+ "expname='exp2_m2o_baseline'\n",
255
+ "expdir='{}/{}'.format(EXPBASEDIR,expname)\n",
256
+ "\n",
257
+ "lang_pair_list=[]\n",
258
+ "for lang in LANGS: \n",
259
+ " lang_pair_list.append([lang,'en'])"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "markdown",
264
+ "metadata": {},
265
+ "source": [
266
+ "**Create Train Corpus**"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": [
275
+ "indir='{}/train'.format(ORG_DATA_DIR)\n",
276
+ "outdir='{}/data'.format(expdir)\n",
277
+ "\n",
278
+ "# print(lang_pair_list)\n",
279
+ "concat_data(indir,outdir,lang_pair_list)"
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "markdown",
284
+ "metadata": {},
285
+ "source": [
286
+ "**Learn BPE**"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "code",
291
+ "execution_count": null,
292
+ "metadata": {},
293
+ "outputs": [],
294
+ "source": [
295
+ "!echo ./learn_bpe.sh {expdir}"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "metadata": {},
302
+ "outputs": [],
303
+ "source": [
304
+ "!echo ./apply_bpe_train_notag.sh {expdir}"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "code",
309
+ "execution_count": null,
310
+ "metadata": {},
311
+ "outputs": [],
312
+ "source": [
313
+ "!echo ./apply_bpe_test_valid_notag.sh {expdir} {ORG_DATA_DIR} {'\"'+' '.join(LANGS+['en'])+'\"'}"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "markdown",
318
+ "metadata": {},
319
+ "source": [
320
+ "**Add language tags to train**"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": null,
326
+ "metadata": {},
327
+ "outputs": [],
328
+ "source": [
329
+ "dset='train' \n",
330
+ "\n",
331
+ "src_fname='{expdir}/bpe/train/{dset}.SRC'.format(expdir=expdir,dset=dset)\n",
332
+ "tgt_fname='{expdir}/bpe/train/{dset}.TGT'.format(expdir=expdir,dset=dset)\n",
333
+ "meta_fname='{expdir}/data/lang_pairs.txt'.format(expdir=expdir,dset=dset)\n",
334
+ " \n",
335
+ "out_src_fname='{expdir}/final/{dset}.SRC'.format(expdir=expdir,dset=dset)\n",
336
+ "out_tgt_fname='{expdir}/final/{dset}.TGT'.format(expdir=expdir,dset=dset)\n",
337
+ "\n",
338
+ "lang_tag_iterator=generate_lang_tag_iterator(meta_fname)\n",
339
+ "\n",
340
+ "print(expdir)\n",
341
+ "os.makedirs('{expdir}/final'.format(expdir=expdir),exist_ok=True)\n",
342
+ "\n",
343
+ "with open(src_fname,'r',encoding='utf-8') as srcfile, \\\n",
344
+ " open(tgt_fname,'r',encoding='utf-8') as tgtfile, \\\n",
345
+ " open(out_src_fname,'w',encoding='utf-8') as outsrcfile, \\\n",
346
+ " open(out_tgt_fname,'w',encoding='utf-8') as outtgtfile: \n",
347
+ "\n",
348
+ " for (l1,l2), src_sent, tgt_sent in tqdm(zip(lang_tag_iterator, srcfile, tgtfile)):\n",
349
+ " outsrcfile.write(add_token(src_sent.strip(),[('src',l1),('tgt',l2)]) + '\\n' )\n",
350
+ " outtgtfile.write(tgt_sent.strip()+'\\n')"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "markdown",
355
+ "metadata": {},
356
+ "source": [
357
+ "**Add language tags to valid**\n",
358
+ "\n",
359
+ "- add language tags, create parallel corpus\n",
360
+ "- sample 20\\% for validation set \n",
361
+ "- Create final validation set"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "code",
366
+ "execution_count": null,
367
+ "metadata": {},
368
+ "outputs": [],
369
+ "source": [
370
+ "dset='dev' \n",
371
+ "out_src_fname='{expdir}/final/{dset}.SRC'.format(\n",
372
+ " expdir=expdir,dset=dset)\n",
373
+ "out_tgt_fname='{expdir}/final/{dset}.TGT'.format(\n",
374
+ " expdir=expdir,dset=dset)\n",
375
+ "\n",
376
+ "os.makedirs('{expdir}/final'.format(expdir=expdir),exist_ok=True)\n",
377
+ "\n",
378
+ "print('Processing validation files') \n",
379
+ "consolidated_dset=[]\n",
380
+ "for l1, l2 in tqdm(lang_pair_list):\n",
381
+ " src_fname='{expdir}/bpe/{dset}/{dset}.{lang}'.format(\n",
382
+ " expdir=expdir,dset=dset,lang=l1)\n",
383
+ " tgt_fname='{expdir}/bpe/{dset}/{dset}.{lang}'.format(\n",
384
+ " expdir=expdir,dset=dset,lang=l2)\n",
385
+ "# print(src_fname)\n",
386
+ "# print(os.path.exists(src_fname))\n",
387
+ " with open(src_fname,'r',encoding='utf-8') as srcfile, \\\n",
388
+ " open(tgt_fname,'r',encoding='utf-8') as tgtfile:\n",
389
+ " for src_sent, tgt_sent in zip(srcfile,tgtfile):\n",
390
+ " consolidated_dset.append(\n",
391
+ " ( add_token(src_sent.strip(),[('src',l1),('tgt',l2)]),\n",
392
+ " tgt_sent.strip() )\n",
393
+ " )\n",
394
+ "\n",
395
+ "print('Create validation set') \n",
396
+ "random.shuffle(consolidated_dset)\n",
397
+ "final_set=consolidated_dset[:len(consolidated_dset)//5] \n",
398
+ "\n",
399
+ "print('Original set size: {}'.format(len(consolidated_dset))) \n",
400
+ "print('Sampled set size: {}'.format(len(final_set))) \n",
401
+ "\n",
402
+ "print('Write validation set')\n",
403
+ "\n",
404
+ "with open(out_src_fname,'w',encoding='utf-8') as srcfile, \\\n",
405
+ " open(out_tgt_fname,'w',encoding='utf-8') as tgtfile:\n",
406
+ " for src_sent, tgt_sent in final_set: \n",
407
+ " srcfile.write(src_sent+'\\n')\n",
408
+ " tgtfile.write(tgt_sent+'\\n')\n"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "markdown",
413
+ "metadata": {},
414
+ "source": [
415
+ "**Add language tags to test**\n",
416
+ "\n",
417
+ "- add language tags, create parallel corpus all M2O language pairs \n",
418
+ "- Create final test set"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "metadata": {},
425
+ "outputs": [],
426
+ "source": [
427
+ "dset='test' \n",
428
+ "out_src_fname='{expdir}/final/{dset}.SRC'.format(\n",
429
+ " expdir=expdir,dset=dset)\n",
430
+ "out_tgt_fname='{expdir}/final/{dset}.TGT'.format(\n",
431
+ " expdir=expdir,dset=dset)\n",
432
+ "\n",
433
+ "os.makedirs('{expdir}/final'.format(expdir=expdir),exist_ok=True)\n",
434
+ "\n",
435
+ "print('Processing test files') \n",
436
+ "consolidated_dset=[]\n",
437
+ "for l1, l2 in tqdm(lang_pair_list):\n",
438
+ " src_fname='{expdir}/bpe/{dset}/{dset}.{lang}'.format(\n",
439
+ " expdir=expdir,dset=dset,lang=l1)\n",
440
+ " tgt_fname='{expdir}/bpe/{dset}/{dset}.{lang}'.format(\n",
441
+ " expdir=expdir,dset=dset,lang=l2)\n",
442
+ "# print(src_fname)\n",
443
+ "# print(os.path.exists(src_fname))\n",
444
+ " with open(src_fname,'r',encoding='utf-8') as srcfile, \\\n",
445
+ " open(tgt_fname,'r',encoding='utf-8') as tgtfile:\n",
446
+ " for src_sent, tgt_sent in zip(srcfile,tgtfile):\n",
447
+ " consolidated_dset.append(\n",
448
+ " ( add_token(src_sent.strip(),[('src',l1),('tgt',l2)]),\n",
449
+ " tgt_sent.strip() )\n",
450
+ " )\n",
451
+ "\n",
452
+ "print('Final set size: {}'.format(len(consolidated_dset))) \n",
453
+ " \n",
454
+ "print('Write test set')\n",
455
+ "print('testset truncated')\n",
456
+ "\n",
457
+ "with open(out_src_fname,'w',encoding='utf-8') as srcfile, \\\n",
458
+ " open(out_tgt_fname,'w',encoding='utf-8') as tgtfile:\n",
459
+ " for lno, (src_sent, tgt_sent) in enumerate(consolidated_dset,1):\n",
460
+ " \n",
461
+ " s=src_sent.strip().split(' ')\n",
462
+ " t=tgt_sent.strip().split(' ')\n",
463
+ " \n",
464
+ " if len(s) > 200 or len(t) > 200:\n",
465
+ " print('exp: {}, pair: ({},{}), lno: {}: lens: ({},{})'.format(expname,l1,l2,lno,len(s),len(t))) \n",
466
+ " \n",
467
+ " src_sent=' '.join( s[:min(len(s),200)] )\n",
468
+ " tgt_sent=' '.join( t[:min(len(t),200)] )\n",
469
+ " \n",
470
+ " srcfile.write(src_sent+'\\n')\n",
471
+ " tgtfile.write(tgt_sent+'\\n')"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "markdown",
476
+ "metadata": {},
477
+ "source": [
478
+ "**Binarize data**"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "metadata": {},
485
+ "outputs": [],
486
+ "source": [
487
+ "!echo ./binarize_training_exp.sh {expdir} SRC TGT"
488
+ ]
489
+ },
490
+ {
491
+ "cell_type": "markdown",
492
+ "metadata": {},
493
+ "source": [
494
+ "**Training Command**"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": null,
500
+ "metadata": {},
501
+ "outputs": [],
502
+ "source": [
503
+ "%%bash \n",
504
+ "\n",
505
+ "python train.py {expdir}/final_bin \\\n",
506
+ " --arch transformer \\\n",
507
+ " --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 1.0 \\\n",
508
+ " --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \\\n",
509
+ " --dropout 0.2 \\\n",
510
+ " --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \\\n",
511
+ " --max-tokens 8192 \\\n",
512
+ " --max-update 1000000 \\\n",
513
+ " --max-source-positions 200 \\\n",
514
+ " --max-target-positions 200 \\\n",
515
+ " --tensorboard-logdir {expdir}/tensorboard \\\n",
516
+ " --save-dir {expdir}/model \\\n",
517
+ " --required-batch-size-multiple 8 \\\n",
518
+ " --save-interval 1 \\\n",
519
+ " --keep-last-epochs 5 \\\n",
520
+ " --patience 5 \\\n",
521
+ " --fp16"
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "markdown",
526
+ "metadata": {},
527
+ "source": [
528
+ "**Cleanup**"
529
+ ]
530
+ },
531
+ {
532
+ "cell_type": "code",
533
+ "execution_count": null,
534
+ "metadata": {},
535
+ "outputs": [],
536
+ "source": [
537
+ "# os.unlink('{}')\n",
538
+ "\n",
539
+ "to_delete=[\n",
540
+ " '{expdir}/data/train.SRC'.format(expdir=expdir,dset=dset),\n",
541
+ " '{expdir}/data/train.TGT'.format(expdir=expdir,dset=dset),\n",
542
+ " '{expdir}/bpe/train/train.SRC'.format(expdir=expdir,dset=dset),\n",
543
+ " '{expdir}/bpe/train/train.TGT'.format(expdir=expdir,dset=dset),\n",
544
+ "]`\n",
545
+ "\n",
546
+ "for fname in to_delete:\n",
547
+ " os.unlink(fname)"
548
+ ]
549
+ },
550
+ {
551
+ "cell_type": "markdown",
552
+ "metadata": {},
553
+ "source": [
554
+ "**Evaluation**"
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "code",
559
+ "execution_count": null,
560
+ "metadata": {},
561
+ "outputs": [],
562
+ "source": [
563
+ "dset='test' \n",
564
+ "consolidated_testoutput_fname='{expdir}/evaluations/test/default/test.SRC_TGT.TGT'.format(expdir=expdir)\n",
565
+ "consolidated_testoutput_log_fname='{}.log'.format(consolidated_testoutput_fname)\n",
566
+ "metrics_fname='{expdir}/evaluations/test/default/test.metrics.tsv'.format(expdir=expdir)\n",
567
+ " \n",
568
+ "test_set_size=2390\n",
569
+ "\n",
570
+ "consolidated_testoutput=[]\n",
571
+ "with open(consolidated_testoutput_log_fname,'r',encoding='utf-8') as hypfile:\n",
572
+ " consolidated_testoutput= list(map(lambda x: x.strip(), filter(lambda x: x.startswith('H-'),hypfile) ))\n",
573
+ " consolidated_testoutput.sort(key=lambda x: int(x.split('\\t')[0].split('-')[1]))\n",
574
+ " consolidated_testoutput=[ x.split('\\t')[2] for x in consolidated_testoutput ]\n",
575
+ "\n",
576
+ "os.makedirs('{expdir}/evaluations/test/default'.format(expdir=expdir),exist_ok=True)\n",
577
+ "\n",
578
+ "with open(consolidated_testoutput_fname,'w',encoding='utf-8') as finalhypfile:\n",
579
+ " for sent in consolidated_testoutput:\n",
580
+ " finalhypfile.write(sent+'\\n')\n",
581
+ "\n",
582
+ "print('Processing test files') \n",
583
+ "with open(metrics_fname,'w',encoding='utf-8') as metrics_file: \n",
584
+ " for i, (l1, l2) in enumerate(tqdm(lang_pair_list)):\n",
585
+ "\n",
586
+ " start=i*test_set_size\n",
587
+ " end=(i+1)*test_set_size\n",
588
+ " hyps=consolidated_testoutput[start:end]\n",
589
+ " ref_fname='{expdir}/{dset}/{dset}.{lang}'.format(\n",
590
+ " expdir=ORG_DATA_DIR,dset=dset,lang=l2)\n",
591
+ "\n",
592
+ " refs=[]\n",
593
+ " with open(ref_fname,'r',encoding='utf-8') as reffile:\n",
594
+ " refs.extend(map(lambda x:x.strip(),reffile))\n",
595
+ "\n",
596
+ " assert(len(hyps)==len(refs))\n",
597
+ "\n",
598
+ " bleu=sacrebleu.corpus_bleu(hyps,[refs],tokenize='none')\n",
599
+ "\n",
600
+ " print('{} {} {} {}'.format(l1,l2,bleu.score,bleu.prec_str))\n",
601
+ " metrics_file.write('{}\\t{}\\t{}\\t{}\\t{}\\n'.format(expname,l1,l2,bleu.score,bleu.prec_str))\n",
602
+ " "
603
+ ]
604
+ }
605
+ ],
606
+ "metadata": {
607
+ "kernelspec": {
608
+ "display_name": "Python 3",
609
+ "language": "python",
610
+ "name": "python3"
611
+ },
612
+ "language_info": {
613
+ "codemirror_mode": {
614
+ "name": "ipython",
615
+ "version": 3
616
+ },
617
+ "file_extension": ".py",
618
+ "mimetype": "text/x-python",
619
+ "name": "python",
620
+ "nbconvert_exporter": "python",
621
+ "pygments_lexer": "ipython3",
622
+ "version": "3.7.0"
623
+ },
624
+ "toc": {
625
+ "base_numbering": 1,
626
+ "nav_menu": {
627
+ "height": "243.993px",
628
+ "width": "160px"
629
+ },
630
+ "number_sections": true,
631
+ "sideBar": true,
632
+ "skip_h1_title": false,
633
+ "title_cell": "Table of Contents",
634
+ "title_sidebar": "Contents",
635
+ "toc_cell": false,
636
+ "toc_position": {},
637
+ "toc_section_display": true,
638
+ "toc_window_display": false
639
+ }
640
+ },
641
+ "nbformat": 4,
642
+ "nbformat_minor": 4
643
+ }
legacy/install_fairseq.sh ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #NVIDIA CUDA download
2
+ wget "https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux"
3
+ wget "http://developer.download.nvidia.com/compute/cuda/10.0/Prod/patches/1/cuda_10.0.130.1_linux.run"
4
+
5
+ ## do not install drivers (See this: https://docs.nvidia.com/deploy/cuda-compatibility/index.html)
6
+ sudo sh "cuda_10.0.130_410.48_linux"
7
+ sudo sh "cuda_10.0.130.1_linux.run"
8
+
9
+ #Set environment variables
10
+ export CUDA_HOME=/usr/local/cuda-10.0
11
+ export PATH=$CUDA_HOME/bin:$PATH
12
+ export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
13
+
14
+ # Install pytorch 1.2
15
+ python3 -m venv pytorch1.2
16
+ source pytorch1.2/bin/activate
17
+ which pip3
18
+ pip3 install torch==1.2.0 torchvision==0.4.0
19
+
20
+ # Install nccl
21
+ git clone https://github.com/NVIDIA/nccl.git
22
+ cd nccl
23
+ make src.build CUDA_HOME=$CUDA_HOME
24
+ sudo apt install build-essential devscripts debhelper fakeroot
25
+ make pkg.debian.build CUDA_HOME=$CUDA_HOME
26
+ sudo dpkg -i build/pkg/deb/libnccl2_2.7.8-1+cuda10.0_amd64.deb
27
+ sudo dpkg -i build/pkg/deb/libnccl-dev_2.7.8-1+cuda10.0_amd64.deb
28
+ sudo apt-get install -f
29
+
30
+ # Install Apex
31
+ git clone https://github.com/NVIDIA/apex
32
+ cd apex
33
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
34
+ --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
35
+ --global-option="--fast_multihead_attn" ./
36
+
37
+ # Install PyArrow
38
+ pip install pyarrow
39
+
40
+ # Install fairseq
41
+ pip install --editable ./
42
+
43
+ # Install other dependencies
44
+ pip install sacrebleu
45
+ pip install tensorboardX --user
legacy/run_inference.sh ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ src_lang=${1:-hi}
2
+ tgt_lang=${2:-en}
3
+ bucket_path=${3:-gs://ai4b-anuvaad-nmt/baselines/transformer-base/baselines-${src_lang}-${tgt_lang}}
4
+
5
+ expdir=../baselines/baselines-${src_lang}-${tgt_lang}
6
+
7
+ if [[ -d $expdir ]]
8
+ then
9
+ echo "$expdir exists on your filesystem. Please delete this if you have made some changes to the bucket files and trying to redownload"
10
+ else
11
+ mkdir -p $expdir
12
+ mkdir -p $expdir/model
13
+ cd ../baselines
14
+ gsutil -m cp -r $bucket_path/vocab $expdir
15
+ gsutil -m cp -r $bucket_path/final_bin $expdir
16
+ gsutil -m cp $bucket_path/model/checkpoint_best.pt $expdir/model
17
+ cd ../indicTrans
18
+ fi
19
+
20
+
21
+ if [ $src_lang == 'hi' ] || [ $tgt_lang == 'hi' ]; then
22
+ #TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest anuvaad-legal tico19 sap-documentation-benchmark all)
23
+ TEST_SETS=( wat2021-devtest wat2020-devtest wat-2018 wmt-news )
24
+ elif [ $src_lang == 'ta' ] || [ $tgt_lang == 'ta' ]; then
25
+ # TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest anuvaad-legal tico19 all)
26
+ TEST_SETS=( wat2021-devtest wat2020-devtest wat-2018 wmt-news ufal-ta)
27
+ elif [ $src_lang == 'bn' ] || [ $tgt_lang == 'bn' ]; then
28
+ # TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal tico19 all)
29
+ TEST_SETS=( wat2021-devtest wat2020-devtest wat-2018)
30
+ elif [ $src_lang == 'gu' ] || [ $tgt_lang == 'gu' ]; then
31
+ # TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest all)
32
+ TEST_SETS=( wat2021-devtest wat2020-devtest wmt-news )
33
+ elif [ $src_lang == 'as' ] || [ $tgt_lang == 'as' ]; then
34
+ TEST_SETS=( pmi )
35
+ elif [ $src_lang == 'kn' ] || [ $tgt_lang == 'kn' ]; then
36
+ # TEST_SETS=( wat2021-devtest anuvaad-legal all)
37
+ TEST_SETS=( wat2021-devtest )
38
+ elif [ $src_lang == 'ml' ] || [ $tgt_lang == 'ml' ]; then
39
+ # TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal all)
40
+ TEST_SETS=( wat2021-devtest wat2020-devtest wat-2018)
41
+ elif [ $src_lang == 'mr' ] || [ $tgt_lang == 'mr' ]; then
42
+ # TEST_SETS=( wat2021-devtest wat2020-devtest all)
43
+ TEST_SETS=( wat2021-devtest wat2020-devtest )
44
+ elif [ $src_lang == 'or' ] || [ $tgt_lang == 'or' ]; then
45
+ TEST_SETS=( wat2021-devtest )
46
+ elif [ $src_lang == 'pa' ] || [ $tgt_lang == 'pa' ]; then
47
+ TEST_SETS=( wat2021-devtest )
48
+ elif [ $src_lang == 'te' ] || [ $tgt_lang == 'te' ]; then
49
+ # TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal all )
50
+ TEST_SETS=( wat2021-devtest wat2020-devtest wat-2018)
51
+ fi
52
+
53
+ if [ $src_lang == 'en' ]; then
54
+ indic_lang=$tgt_lang
55
+ else
56
+ indic_lang=$src_lang
57
+ fi
58
+
59
+
60
+ for tset in ${TEST_SETS[@]};do
61
+ echo $tset $src_lang $tgt_lang
62
+ if [ $tset == 'wat2021-devtest' ]; then
63
+ SRC_FILE=${expdir}/benchmarks/$tset/test.$src_lang
64
+ REF_FILE=${expdir}/benchmarks/$tset/test.$tgt_lang
65
+ else
66
+ SRC_FILE=${expdir}/benchmarks/$tset/en-${indic_lang}/test.$src_lang
67
+ REF_FILE=${expdir}/benchmarks/$tset/en-${indic_lang}/test.$tgt_lang
68
+ fi
69
+ RESULTS_DIR=${expdir}/results/$tset
70
+
71
+ mkdir -p $RESULTS_DIR
72
+
73
+ bash translate.sh $SRC_FILE $RESULTS_DIR/${src_lang}-${tgt_lang} $src_lang $tgt_lang $expdir $REF_FILE
74
+ # for newline between different outputs
75
+ echo
76
+ done
77
+ # send the results to the bucket
78
+ gsutil -m cp -r $expdir/results $bucket_path
79
+ # clear up the space in the instance
80
+ # rm -r $expdir
legacy/run_joint_inference.sh ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ src_lang=${1:-en}
2
+ tgt_lang=${2:-indic}
3
+ bucket_path=${3:-gs://ai4b-anuvaad-nmt/models/transformer-4x/indictrans-${src_lang}-${tgt_lang}}
4
+
5
+ mkdir -p ../baselines
6
+ expdir=../baselines/baselines-${src_lang}-${tgt_lang}
7
+
8
+ if [[ -d $expdir ]]
9
+ then
10
+ echo "$expdir exists on your filesystem."
11
+ else
12
+ cd ../baselines
13
+ mkdir -p baselines-${src_lang}-${tgt_lang}/model
14
+ mkdir -p baselines-${src_lang}-${tgt_lang}/final_bin
15
+ cd baselines-${src_lang}-${tgt_lang}/model
16
+ gsutil -m cp $bucket_path/model/checkpoint_best.pt .
17
+ cd ..
18
+ gsutil -m cp $bucket_path/vocab .
19
+ gsutil -m cp $bucket_path/final_bin/dict.* final_bin
20
+ cd ../indicTrans
21
+ fi
22
+
23
+
24
+
25
+
26
+
27
+ if [ $src_lang == 'hi' ] || [ $tgt_lang == 'hi' ]; then
28
+ TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest anuvaad-legal tico19 sap-documentation-benchmark all)
29
+ elif [ $src_lang == 'ta' ] || [ $tgt_lang == 'ta' ]; then
30
+ TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest anuvaad-legal tico19 all)
31
+ elif [ $src_lang == 'bn' ] || [ $tgt_lang == 'bn' ]; then
32
+ TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal tico19 all)
33
+ elif [ $src_lang == 'gu' ] || [ $tgt_lang == 'gu' ]; then
34
+ TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest all)
35
+ elif [ $src_lang == 'as' ] || [ $tgt_lang == 'as' ]; then
36
+ TEST_SETS=( all )
37
+ elif [ $src_lang == 'kn' ] || [ $tgt_lang == 'kn' ]; then
38
+ TEST_SETS=( wat2021-devtest anuvaad-legal all)
39
+ elif [ $src_lang == 'ml' ] || [ $tgt_lang == 'ml' ]; then
40
+ TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal all)
41
+ elif [ $src_lang == 'mr' ] || [ $tgt_lang == 'mr' ]; then
42
+ TEST_SETS=( wat2021-devtest wat2020-devtest all)
43
+ elif [ $src_lang == 'or' ] || [ $tgt_lang == 'or' ]; then
44
+ TEST_SETS=( all )
45
+ elif [ $src_lang == 'pa' ] || [ $tgt_lang == 'pa' ]; then
46
+ TEST_SETS=( all )
47
+ elif [ $src_lang == 'te' ] || [ $tgt_lang == 'te' ]; then
48
+ TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal all )
49
+ fi
50
+
51
+ if [ $src_lang == 'en' ]; then
52
+ indic_lang=$tgt_lang
53
+ else
54
+ indic_lang=$src_lang
55
+ fi
56
+
57
+
58
+ for tset in ${TEST_SETS[@]};do
59
+ echo $tset $src_lang $tgt_lang
60
+ if [ $tset == 'wat2021-devtest' ]; then
61
+ SRC_FILE=${expdir}/devtest/$tset/test.$src_lang
62
+ REF_FILE=${expdir}/devtest/$tset/test.$tgt_lang
63
+ else
64
+ SRC_FILE=${expdir}/devtest/$tset/en-${indic_lang}/test.$src_lang
65
+ REF_FILE=${expdir}/devtest/$tset/en-${indic_lang}/test.$tgt_lang
66
+ fi
67
+ RESULTS_DIR=${expdir}/results/$tset
68
+
69
+ mkdir -p $RESULTS_DIR
70
+
71
+ bash joint_translate.sh $SRC_FILE $RESULTS_DIR/${src_lang}-${tgt_lang} $src_lang $tgt_lang $expdir $REF_FILE
72
+ # for newline between different outputs
73
+ echo
74
+ done
legacy/tpu_training_instructions.md ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Instructions to run on Google cloud TPUs
2
+ Before starting these steps, make sure to prepare the dataset (normalization -> bpe -> .. -> binarization) following the steps in indicTrans workflow or do these steps on a cpu instance before launching the tpu instance (to save time and costs)
3
+
4
+ ### Creating TPU instance
5
+
6
+ - Create a cpu instance on gcp with `torch-xla` image like:
7
+ ```bash
8
+ gcloud compute --project=${PROJECT_ID} instances create <name for your instance> \
9
+ --zone=<zone> \
10
+ --machine-type=n1-standard-16 \
11
+ --image-family=torch-xla \
12
+ --image-project=ml-images \
13
+ --boot-disk-size=200GB \
14
+ --scopes=https://www.googleapis.com/auth/cloud-platform
15
+ ```
16
+ - Once the instance is created, Launch a Cloud TPU (from your cpu vm instance) using the following command (you can change the `accelerator_type` according to your needs):
17
+ ```bash
18
+ gcloud compute tpus create <name for your TPU> \
19
+ --zone=<zone> \
20
+ --network=default \
21
+ --version=pytorch-1.7 \
22
+ --accelerator-type=v3-8
23
+ ```
24
+ (or)
25
+ Create a new tpu using the GUI in https://console.cloud.google.com/compute/tpus and make sure to select `version` as `pytorch 1.7`.
26
+
27
+ - Once the tpu is launched, identify its ip address:
28
+ ```bash
29
+ # you can run this inside cpu instance and note down the IP address which is located under the NETWORK_ENDPOINTS column
30
+ gcloud compute tpus list --zone=us-central1-a
31
+ ```
32
+ (or)
33
+ Go to https://console.cloud.google.com/compute/tpus and note down ip address for the created TPU from the `interal ip` column
34
+
35
+ ### Installing Fairseq, getting data on the cpu instance
36
+
37
+ - Activate the `torch xla 1.7` conda environment and install necessary libs for IndicTrans (**Excluding FairSeq**):
38
+ ```bash
39
+ conda activate torch-xla-1.7
40
+ pip install sacremoses pandas mock sacrebleu tensorboardX pyarrow
41
+ ```
42
+ - Configure environment variables for TPU:
43
+ ```bash
44
+ export TPU_IP_ADDRESS=ip-address; \
45
+ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
46
+ ```
47
+ - Download the prepared binarized data for FairSeq
48
+
49
+ - Clone the latest version of Fairseq (this supports tpu) and install from source. There is an [issue](https://github.com/pytorch/fairseq/issues/3259) with the latest commit and hence we use a different commit to install from source (This may have been fixed in the latest master but we have not tested it.)
50
+ ```bash
51
+ git clone https://github.com/pytorch/fairseq.git
52
+ git checkout da9eaba12d82b9bfc1442f0e2c6fc1b895f4d35d
53
+ pip install --editable ./
54
+ ```
55
+
56
+ - Start TPU training
57
+ ```bash
58
+ # this is for using all tpu cores
59
+ export MKL_SERVICE_FORCE_INTEL=1
60
+
61
+ fairseq-train {expdir}/exp2_m2o_baseline/final_bin \
62
+ --max-source-positions=200 \
63
+ --max-target-positions=200 \
64
+ --max-update=1000000 \
65
+ --save-interval=5 \
66
+ --arch=transformer \
67
+ --attention-dropout=0.1 \
68
+ --criterion=label_smoothed_cross_entropy \
69
+ --source-lang=SRC \
70
+ --lr-scheduler=inverse_sqrt \
71
+ --skip-invalid-size-inputs-valid-test \
72
+ --target-lang=TGT \
73
+ --label-smoothing=0.1 \
74
+ --update-freq=1 \
75
+ --optimizer adam \
76
+ --adam-betas '(0.9, 0.98)' \
77
+ --warmup-init-lr 1e-07 \
78
+ --lr 0.0005 \
79
+ --warmup-updates 4000 \
80
+ --dropout 0.2 \
81
+ --weight-decay 0.0 \
82
+ --tpu \
83
+ --distributed-world-size 8 \
84
+ --max-tokens 8192 \
85
+ --num-batch-buckets 8 \
86
+ --tensorboard-logdir {expdir}/exp2_m2o_baseline/tensorboard \
87
+ --save-dir {expdir}/exp2_m2o_baseline/model \
88
+ --keep-last-epochs 5 \
89
+ --patience 5
90
+ ```
91
+
92
+ **Note** While training, we noticed that the training was slower on tpus, compared to using multiple GPUs, we have documented some issues and [filed an issue](https://github.com/pytorch/fairseq/issues/3317) at fairseq repo for advice. We'll update this section as we learn more about efficient training on TPUs. Also feel free to open an issue/pull request if you find a bug or know an efficient method to make code train faster on tpus.
legacy/translate.sh ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ echo `date`
3
+ infname=$1
4
+ outfname=$2
5
+ src_lang=$3
6
+ tgt_lang=$4
7
+ exp_dir=$5
8
+ ref_fname=$6
9
+
10
+ if [ $src_lang == 'en' ]; then
11
+ SRC_PREFIX='TGT'
12
+ TGT_PREFIX='SRC'
13
+ else
14
+ SRC_PREFIX='SRC'
15
+ TGT_PREFIX='TGT'
16
+ fi
17
+
18
+ #`dirname $0`/env.sh
19
+ SUBWORD_NMT_DIR='subword-nmt'
20
+ model_dir=$exp_dir/model
21
+ data_bin_dir=$exp_dir/final_bin
22
+
23
+ ### normalization and script conversion
24
+
25
+ echo "Applying normalization and script conversion"
26
+ input_size=`python preprocess_translate.py $infname $outfname.norm $src_lang`
27
+ echo "Number of sentences in input: $input_size"
28
+
29
+ ### apply BPE to input file
30
+
31
+ echo "Applying BPE"
32
+ python $SUBWORD_NMT_DIR/subword_nmt/apply_bpe.py \
33
+ -c $exp_dir/vocab/bpe_codes.32k.${SRC_PREFIX}_${TGT_PREFIX} \
34
+ --vocabulary $exp_dir/vocab/vocab.$SRC_PREFIX \
35
+ --vocabulary-threshold 5 \
36
+ < $outfname.norm \
37
+ > $outfname.bpe
38
+
39
+ # not needed for joint training
40
+ # echo "Adding language tags"
41
+ # python add_tags_translate.py $outfname._bpe $outfname.bpe $src_lang $tgt_lang
42
+
43
+ ### run decoder
44
+
45
+ echo "Decoding"
46
+
47
+ src_input_bpe_fname=$outfname.bpe
48
+ tgt_output_fname=$outfname
49
+ fairseq-interactive $data_bin_dir \
50
+ -s $SRC_PREFIX -t $TGT_PREFIX \
51
+ --distributed-world-size 1 \
52
+ --path $model_dir/checkpoint_best.pt \
53
+ --batch-size 64 --buffer-size 2500 --beam 5 --remove-bpe \
54
+ --skip-invalid-size-inputs-valid-test \
55
+ --input $src_input_bpe_fname > $tgt_output_fname.log 2>&1
56
+
57
+
58
+ echo "Extracting translations, script conversion and detokenization"
59
+ python postprocess_translate.py $tgt_output_fname.log $tgt_output_fname $input_size $tgt_lang
60
+ if [ $src_lang == 'en' ]; then
61
+ # indicnlp tokenize the output files before evaluation
62
+ input_size=`python preprocess_translate.py $ref_fname $ref_fname.tok $tgt_lang`
63
+ input_size=`python preprocess_translate.py $tgt_output_fname $tgt_output_fname.tok $tgt_lang`
64
+ sacrebleu --tokenize none $ref_fname.tok < $tgt_output_fname.tok
65
+ else
66
+ # indic to en models
67
+ sacrebleu $ref_fname < $tgt_output_fname
68
+ fi
69
+ echo `date`
70
+ echo "Translation completed"
model_configs/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import custom_transformer
model_configs/custom_transformer.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fairseq.models import register_model_architecture
2
+ from fairseq.models.transformer import base_architecture
3
+
4
+
5
+ @register_model_architecture("transformer", "transformer_2x")
6
+ def transformer_big(args):
7
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
8
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
9
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
10
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
11
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
12
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
13
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
14
+ base_architecture(args)
15
+
16
+
17
+ @register_model_architecture("transformer", "transformer_4x")
18
+ def transformer_huge(args):
19
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1536)
20
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
21
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
22
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
23
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536)
24
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
25
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
26
+ base_architecture(args)
27
+
28
+
29
+ @register_model_architecture("transformer", "transformer_9x")
30
+ def transformer_xlarge(args):
31
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 2048)
32
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 8192)
33
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
34
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
35
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
36
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
37
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
38
+ base_architecture(args)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sacremoses
2
+ pandas
3
+ mock
4
+ sacrebleu
5
+ pyarrow
6
+ indic-nlp-library
7
+ mosestokenizer
8
+ subword-nmt
9
+ numpy
10
+ tensorboardX
11
+ git+https://github.com/pytorch/fairseq.git
scripts/__init__.py ADDED
File without changes
scripts/add_joint_tags_translate.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from tqdm import tqdm
3
+ import os
4
+
5
+
6
+ def add_token(sent, tag_infos):
7
+ """ add special tokens specified by tag_infos to each element in list
8
+
9
+ tag_infos: list of tuples (tag_type,tag)
10
+
11
+ each tag_info results in a token of the form: __{tag_type}__{tag}__
12
+
13
+ """
14
+
15
+ tokens = []
16
+ for tag_type, tag in tag_infos:
17
+ token = '__' + tag_type + '__' + tag + '__'
18
+ tokens.append(token)
19
+
20
+ return ' '.join(tokens) + ' ' + sent
21
+
22
+
23
+ def generate_lang_tag_iterator(infname):
24
+ with open(infname, 'r', encoding='utf-8') as infile:
25
+ for line in infile:
26
+ src, tgt, count = line.strip().split('\t')
27
+ count = int(count)
28
+ for _ in range(count):
29
+ yield (src, tgt)
30
+
31
+
32
+ if __name__ == '__main__':
33
+
34
+ expdir = sys.argv[1]
35
+ dset = sys.argv[2]
36
+
37
+ src_fname = '{expdir}/bpe/{dset}.SRC'.format(
38
+ expdir=expdir, dset=dset)
39
+ tgt_fname = '{expdir}/bpe/{dset}.TGT'.format(
40
+ expdir=expdir, dset=dset)
41
+ meta_fname = '{expdir}/data/{dset}_lang_pairs.txt'.format(
42
+ expdir=expdir, dset=dset)
43
+
44
+ out_src_fname = '{expdir}/final/{dset}.SRC'.format(
45
+ expdir=expdir, dset=dset)
46
+ out_tgt_fname = '{expdir}/final/{dset}.TGT'.format(
47
+ expdir=expdir, dset=dset)
48
+ lang_tag_iterator = generate_lang_tag_iterator(meta_fname)
49
+
50
+ os.makedirs('{expdir}/final'.format(expdir=expdir), exist_ok=True)
51
+
52
+ with open(src_fname, 'r', encoding='utf-8') as srcfile, \
53
+ open(tgt_fname, 'r', encoding='utf-8') as tgtfile, \
54
+ open(out_src_fname, 'w', encoding='utf-8') as outsrcfile, \
55
+ open(out_tgt_fname, 'w', encoding='utf-8') as outtgtfile:
56
+
57
+ for (l1, l2), src_sent, tgt_sent in tqdm(zip(lang_tag_iterator,
58
+ srcfile, tgtfile)):
59
+ outsrcfile.write(add_token(src_sent.strip(), [
60
+ ('src', l1), ('tgt', l2)]) + '\n')
61
+ outtgtfile.write(tgt_sent.strip() + '\n')
scripts/add_tags_translate.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+
4
+ def add_token(sent, tag_infos):
5
+ """ add special tokens specified by tag_infos to each element in list
6
+
7
+ tag_infos: list of tuples (tag_type,tag)
8
+
9
+ each tag_info results in a token of the form: __{tag_type}__{tag}__
10
+
11
+ """
12
+
13
+ tokens = []
14
+ for tag_type, tag in tag_infos:
15
+ token = '__' + tag_type + '__' + tag + '__'
16
+ tokens.append(token)
17
+
18
+ return ' '.join(tokens) + ' ' + sent
19
+
20
+
21
+ if __name__ == '__main__':
22
+
23
+ infname = sys.argv[1]
24
+ outfname = sys.argv[2]
25
+ src_lang = sys.argv[3]
26
+ tgt_lang = sys.argv[4]
27
+
28
+ with open(infname, 'r', encoding='utf-8') as infile, \
29
+ open(outfname, 'w', encoding='utf-8') as outfile:
30
+ for line in infile:
31
+ outstr = add_token(
32
+ line.strip(), [('src', src_lang), ('tgt', tgt_lang)])
33
+ outfile.write(outstr + '\n')
scripts/clean_vocab.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import codecs
3
+
4
+ def clean_vocab(in_vocab_fname, out_vocab_fname):
5
+ with codecs.open(in_vocab_fname, "r", encoding="utf-8") as infile, codecs.open(
6
+ out_vocab_fname, "w", encoding="utf-8"
7
+ ) as outfile:
8
+ for i, line in enumerate(infile):
9
+ fields = line.strip("\r\n ").split(" ")
10
+ if len(fields) == 2:
11
+ outfile.write(line)
12
+ if len(fields) != 2:
13
+ print("{}: {}".format(i, line.strip()))
14
+ for c in line:
15
+ print("{}:{}".format(c, hex(ord(c))))
16
+
17
+
18
+ if __name__ == "__main__":
19
+ clean_vocab(sys.argv[1], sys.argv[2])
scripts/concat_joint_data.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm import tqdm
3
+ import sys
4
+
5
+ LANGS = [
6
+ "as",
7
+ "bn",
8
+ "gu",
9
+ "hi",
10
+ "kn",
11
+ "ml",
12
+ "mr",
13
+ "or",
14
+ "pa",
15
+ "ta",
16
+ "te",
17
+ #"ur"
18
+ ]
19
+
20
+
21
+ def add_token(sent, tag_infos):
22
+ """ add special tokens specified by tag_infos to each element in list
23
+
24
+ tag_infos: list of tuples (tag_type,tag)
25
+
26
+ each tag_info results in a token of the form: __{tag_type}__{tag}__
27
+
28
+ """
29
+
30
+ tokens = []
31
+ for tag_type, tag in tag_infos:
32
+ token = '__' + tag_type + '__' + tag + '__'
33
+ tokens.append(token)
34
+
35
+ return ' '.join(tokens) + ' ' + sent
36
+
37
+
38
+ def concat_data(data_dir, outdir, lang_pair_list,
39
+ out_src_lang='SRC', out_trg_lang='TGT', split='train'):
40
+ """
41
+ data_dir: input dir, contains directories for language pairs named l1-l2
42
+ """
43
+ os.makedirs(outdir, exist_ok=True)
44
+
45
+ out_src_fname = '{}/{}.{}'.format(outdir, split, out_src_lang)
46
+ out_trg_fname = '{}/{}.{}'.format(outdir, split, out_trg_lang)
47
+ # out_meta_fname='{}/metadata.txt'.format(outdir)
48
+
49
+ print()
50
+ print(out_src_fname)
51
+ print(out_trg_fname)
52
+ # print(out_meta_fname)
53
+
54
+ # concatenate train data
55
+ if os.path.isfile(out_src_fname):
56
+ os.unlink(out_src_fname)
57
+ if os.path.isfile(out_trg_fname):
58
+ os.unlink(out_trg_fname)
59
+ # if os.path.isfile(out_meta_fname):
60
+ # os.unlink(out_meta_fname)
61
+
62
+ for src_lang, trg_lang in tqdm(lang_pair_list):
63
+ print('src: {}, tgt:{}'.format(src_lang, trg_lang))
64
+
65
+ in_src_fname = '{}/{}-{}/{}.{}'.format(
66
+ data_dir, src_lang, trg_lang, split, src_lang)
67
+ in_trg_fname = '{}/{}-{}/{}.{}'.format(
68
+ data_dir, src_lang, trg_lang, split, trg_lang)
69
+
70
+ if not os.path.exists(in_src_fname):
71
+ continue
72
+ if not os.path.exists(in_trg_fname):
73
+ continue
74
+
75
+ print(in_src_fname)
76
+ os.system('cat {} >> {}'.format(in_src_fname, out_src_fname))
77
+
78
+ print(in_trg_fname)
79
+ os.system('cat {} >> {}'.format(in_trg_fname, out_trg_fname))
80
+
81
+
82
+ # with open('{}/lang_pairs.txt'.format(outdir),'w',encoding='utf-8') as lpfile:
83
+ # lpfile.write('\n'.join( [ '-'.join(x) for x in lang_pair_list ] ))
84
+
85
+ corpus_stats(data_dir, outdir, lang_pair_list, split)
86
+
87
+
88
+ def corpus_stats(data_dir, outdir, lang_pair_list, split):
89
+ """
90
+ data_dir: input dir, contains directories for language pairs named l1-l2
91
+ """
92
+
93
+ with open('{}/{}_lang_pairs.txt'.format(outdir, split), 'w', encoding='utf-8') as lpfile:
94
+
95
+ for src_lang, trg_lang in tqdm(lang_pair_list):
96
+ print('src: {}, tgt:{}'.format(src_lang, trg_lang))
97
+
98
+ in_src_fname = '{}/{}-{}/{}.{}'.format(
99
+ data_dir, src_lang, trg_lang, split, src_lang)
100
+ # in_trg_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,trg_lang)
101
+ if not os.path.exists(in_src_fname):
102
+ continue
103
+
104
+ print(in_src_fname)
105
+ corpus_size = 0
106
+ with open(in_src_fname, 'r', encoding='utf-8') as infile:
107
+ corpus_size = sum(map(lambda x: 1, infile))
108
+
109
+ lpfile.write('{}\t{}\t{}\n'.format(
110
+ src_lang, trg_lang, corpus_size))
111
+
112
+
113
+ if __name__ == '__main__':
114
+
115
+ in_dir = sys.argv[1]
116
+ out_dir = sys.argv[2]
117
+ src_lang = sys.argv[3]
118
+ tgt_lang = sys.argv[4]
119
+ split = sys.argv[5]
120
+ lang_pair_list = []
121
+
122
+ if src_lang == 'en':
123
+ for lang in LANGS:
124
+ lang_pair_list.append(['en', lang])
125
+ else:
126
+ for lang in LANGS:
127
+ lang_pair_list.append([lang, 'en'])
128
+
129
+ concat_data(in_dir, out_dir, lang_pair_list, split=split)
130
+
scripts/extract_non_english_pairs.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import os
3
+ from collections import defaultdict
4
+
5
+
6
+ def read_file(fname):
7
+ with open(fname, "r", encoding="utf-8") as infile:
8
+ for line in infile:
9
+ yield line.strip()
10
+
11
+
12
+ def extract_non_english_pairs(indir, outdir, LANGS):
13
+ """
14
+ Extracts non-english pair parallel corpora
15
+
16
+ indir: contains english centric data in the following form:
17
+ - directory named en-xx for language xx
18
+ - each directory contains a train.en and train.xx
19
+ outdir: output directory to store mined data for each pair.
20
+ One directory is created for each pair.
21
+ LANGS: list of languages in the corpus (other than English).
22
+ The language codes must correspond to the ones used in the
23
+ files and directories in indir. Prefarably, sort the languages
24
+ in this list in alphabetic order. outdir will contain data for xx-yy,
25
+ but not for yy-xx, so it will be convenient to have this list in sorted order.
26
+ """
27
+
28
+ for i in tqdm(range(len(LANGS) - 1)):
29
+ print()
30
+ for j in range(i + 1, len(LANGS)):
31
+ lang1 = LANGS[i]
32
+ lang2 = LANGS[j]
33
+ # print()
34
+ print("{} {}".format(lang1, lang2))
35
+
36
+ fname1 = "{}/en-{}/train.en".format(indir, lang1)
37
+ fname2 = "{}/en-{}/train.en".format(indir, lang2)
38
+ # print(fname1)
39
+ # print(fname2)
40
+ enset_l1 = set(read_file(fname1))
41
+ common_en_set = enset_l1.intersection(read_file(fname2))
42
+
43
+ ## this block should be used if you want to consider multiple translations.
44
+ # il_fname1 = "{}/en-{}/train.{}".format(indir, lang1, lang1)
45
+ # en_lang1_dict = defaultdict(list)
46
+ # for en_line, il_line in zip(read_file(fname1), read_file(il_fname1)):
47
+ # if en_line in common_en_set:
48
+ # en_lang1_dict[en_line].append(il_line)
49
+
50
+ # # this block should be used if you DONT to consider multiple translation.
51
+ il_fname1='{}/en-{}/train.{}'.format(indir,lang1,lang1)
52
+ en_lang1_dict={}
53
+ for en_line,il_line in zip(read_file(fname1),read_file(il_fname1)):
54
+ if en_line in common_en_set:
55
+ en_lang1_dict[en_line]=il_line
56
+
57
+ os.makedirs("{}/{}-{}".format(outdir, lang1, lang2), exist_ok=True)
58
+ out_l1_fname = "{o}/{l1}-{l2}/train.{l1}".format(
59
+ o=outdir, l1=lang1, l2=lang2
60
+ )
61
+ out_l2_fname = "{o}/{l1}-{l2}/train.{l2}".format(
62
+ o=outdir, l1=lang1, l2=lang2
63
+ )
64
+
65
+ il_fname2 = "{}/en-{}/train.{}".format(indir, lang2, lang2)
66
+ with open(out_l1_fname, "w", encoding="utf-8") as out_l1_file, open(
67
+ out_l2_fname, "w", encoding="utf-8"
68
+ ) as out_l2_file:
69
+ for en_line, il_line in zip(read_file(fname2), read_file(il_fname2)):
70
+ if en_line in en_lang1_dict:
71
+
72
+ # this block should be used if you want to consider multiple tranlations.
73
+ for il_line_lang1 in en_lang1_dict[en_line]:
74
+ # lang1_line, lang2_line = il_line_lang1, il_line
75
+ # out_l1_file.write(lang1_line + "\n")
76
+ # out_l2_file.write(lang2_line + "\n")
77
+
78
+ # this block should be used if you DONT to consider multiple translation.
79
+ lang1_line, lang2_line = en_lang1_dict[en_line], il_line
80
+ out_l1_file.write(lang1_line+'\n')
81
+ out_l2_file.write(lang2_line+'\n')
82
+
83
+
84
+ def get_extracted_stats(outdir, LANGS):
85
+ """
86
+ gathers stats from the extracted directories
87
+
88
+ outdir: output directory to store mined data for each pair.
89
+ One directory is created for each pair.
90
+ LANGS: list of languages in the corpus (other than languages).
91
+ The language codes must correspond to the ones used in the
92
+ files and directories in indir. Prefarably, sort the languages
93
+ in this list in alphabetic order. outdir will contain data for xx-yy,
94
+ """
95
+ common_stats = []
96
+ for i in tqdm(range(len(LANGS) - 1)):
97
+ for j in range(i + 1, len(LANGS)):
98
+ lang1 = LANGS[i]
99
+ lang2 = LANGS[j]
100
+
101
+ out_l1_fname = "{o}/{l1}-{l2}/train.{l1}".format(
102
+ o=outdir, l1=lang1, l2=lang2
103
+ )
104
+
105
+ cnt = sum([1 for _ in read_file(out_l1_fname)])
106
+ common_stats.append((lang1, lang2, cnt))
107
+ common_stats.append((lang2, lang1, cnt))
108
+ return common_stats
scripts/postprocess_score.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ def postprocess(
4
+ infname, outfname, input_size
5
+ ):
6
+ """
7
+ parse fairseq interactive output, convert script back to native Indic script (in case of Indic languages) and detokenize.
8
+
9
+ infname: fairseq log file
10
+ outfname: output file of translation (sentences not translated contain the dummy string 'DUMMY_OUTPUT'
11
+ input_size: expected number of output sentences
12
+ """
13
+
14
+ consolidated_testoutput = []
15
+ # with open(infname,'r',encoding='utf-8') as infile:
16
+ # consolidated_testoutput= list(map(lambda x: x.strip(), filter(lambda x: x.startswith('H-'),infile) ))
17
+ # consolidated_testoutput.sort(key=lambda x: int(x.split('\t')[0].split('-')[1]))
18
+ # consolidated_testoutput=[ x.split('\t')[2] for x in consolidated_testoutput ]
19
+
20
+ consolidated_testoutput = [(x, 0.0, "") for x in range(input_size)]
21
+ temp_testoutput = []
22
+ with open(infname, "r", encoding="utf-8") as infile:
23
+ temp_testoutput = list(
24
+ map(
25
+ lambda x: x.strip().split("\t"),
26
+ filter(lambda x: x.startswith("H-"), infile),
27
+ )
28
+ )
29
+ temp_testoutput = list(
30
+ map(lambda x: (int(x[0].split("-")[1]), float(x[1]), x[2]), temp_testoutput)
31
+ )
32
+ for sid, score, hyp in temp_testoutput:
33
+ consolidated_testoutput[sid] = (sid, score, hyp)
34
+ #consolidated_testoutput = [x[2] for x in consolidated_testoutput]
35
+
36
+ with open(outfname, "w", encoding="utf-8") as outfile:
37
+ for (sid, score, hyp) in consolidated_testoutput:
38
+ outfile.write("{}\n".format(score))
39
+
40
+ if __name__ == "__main__":
41
+
42
+ infname = sys.argv[1]
43
+ outfname = sys.argv[2]
44
+ input_size = int(sys.argv[3])
45
+
46
+ postprocess(
47
+ infname, outfname, input_size
48
+ )
scripts/postprocess_translate.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ INDIC_NLP_LIB_HOME = "indic_nlp_library"
2
+ INDIC_NLP_RESOURCES = "indic_nlp_resources"
3
+ import sys
4
+
5
+ from indicnlp import transliterate
6
+
7
+ sys.path.append(r"{}".format(INDIC_NLP_LIB_HOME))
8
+ from indicnlp import common
9
+
10
+ common.set_resources_path(INDIC_NLP_RESOURCES)
11
+ from indicnlp import loader
12
+
13
+ loader.load()
14
+ from sacremoses import MosesPunctNormalizer
15
+ from sacremoses import MosesTokenizer
16
+ from sacremoses import MosesDetokenizer
17
+ from collections import defaultdict
18
+
19
+ import indicnlp
20
+ from indicnlp.tokenize import indic_tokenize
21
+ from indicnlp.tokenize import indic_detokenize
22
+ from indicnlp.normalize import indic_normalize
23
+ from indicnlp.transliterate import unicode_transliterate
24
+
25
+
26
+ def postprocess(
27
+ infname, outfname, input_size, lang, common_lang="hi", transliterate=False
28
+ ):
29
+ """
30
+ parse fairseq interactive output, convert script back to native Indic script (in case of Indic languages) and detokenize.
31
+
32
+ infname: fairseq log file
33
+ outfname: output file of translation (sentences not translated contain the dummy string 'DUMMY_OUTPUT'
34
+ input_size: expected number of output sentences
35
+ lang: language
36
+ """
37
+
38
+ consolidated_testoutput = []
39
+ # with open(infname,'r',encoding='utf-8') as infile:
40
+ # consolidated_testoutput= list(map(lambda x: x.strip(), filter(lambda x: x.startswith('H-'),infile) ))
41
+ # consolidated_testoutput.sort(key=lambda x: int(x.split('\t')[0].split('-')[1]))
42
+ # consolidated_testoutput=[ x.split('\t')[2] for x in consolidated_testoutput ]
43
+
44
+ consolidated_testoutput = [(x, 0.0, "") for x in range(input_size)]
45
+ temp_testoutput = []
46
+ with open(infname, "r", encoding="utf-8") as infile:
47
+ temp_testoutput = list(
48
+ map(
49
+ lambda x: x.strip().split("\t"),
50
+ filter(lambda x: x.startswith("H-"), infile),
51
+ )
52
+ )
53
+ temp_testoutput = list(
54
+ map(lambda x: (int(x[0].split("-")[1]), float(x[1]), x[2]), temp_testoutput)
55
+ )
56
+ for sid, score, hyp in temp_testoutput:
57
+ consolidated_testoutput[sid] = (sid, score, hyp)
58
+ consolidated_testoutput = [x[2] for x in consolidated_testoutput]
59
+
60
+ if lang == "en":
61
+ en_detok = MosesDetokenizer(lang="en")
62
+ with open(outfname, "w", encoding="utf-8") as outfile:
63
+ for sent in consolidated_testoutput:
64
+ outfile.write(en_detok.detokenize(sent.split(" ")) + "\n")
65
+ else:
66
+ xliterator = unicode_transliterate.UnicodeIndicTransliterator()
67
+ with open(outfname, "w", encoding="utf-8") as outfile:
68
+ for sent in consolidated_testoutput:
69
+ if transliterate:
70
+ outstr = indic_detokenize.trivial_detokenize(
71
+ xliterator.transliterate(sent, common_lang, lang), lang
72
+ )
73
+ else:
74
+ outstr = indic_detokenize.trivial_detokenize(sent, lang)
75
+ outfile.write(outstr + "\n")
76
+
77
+
78
+ if __name__ == "__main__":
79
+ # # The path to the local git repo for Indic NLP library
80
+ # INDIC_NLP_LIB_HOME="indic_nlp_library"
81
+ # INDIC_NLP_RESOURCES = "indic_nlp_resources"
82
+ # sys.path.append('{}'.format(INDIC_NLP_LIB_HOME))
83
+ # common.set_resources_path(INDIC_NLP_RESOURCES)
84
+ # # The path to the local git repo for Indic NLP Resources
85
+ # INDIC_NLP_RESOURCES=""
86
+
87
+ # sys.path.append('{}'.format(INDIC_NLP_LIB_HOME))
88
+ # common.set_resources_path(INDIC_NLP_RESOURCES)
89
+
90
+ # loader.load()
91
+
92
+ infname = sys.argv[1]
93
+ outfname = sys.argv[2]
94
+ input_size = int(sys.argv[3])
95
+ lang = sys.argv[4]
96
+ if len(sys.argv) == 5:
97
+ transliterate = False
98
+ elif len(sys.argv) == 6:
99
+ transliterate = sys.argv[5]
100
+ if transliterate.lower() == "true":
101
+ transliterate = True
102
+ else:
103
+ transliterate = False
104
+ else:
105
+ print(f"Invalid arguments: {sys.argv}")
106
+ exit()
107
+
108
+ postprocess(
109
+ infname, outfname, input_size, lang, common_lang="hi", transliterate=transliterate
110
+ )
scripts/preprocess_translate.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ INDIC_NLP_LIB_HOME = "indic_nlp_library"
2
+ INDIC_NLP_RESOURCES = "indic_nlp_resources"
3
+ import sys
4
+
5
+ sys.path.append(r"{}".format(INDIC_NLP_LIB_HOME))
6
+ from indicnlp import common
7
+
8
+ common.set_resources_path(INDIC_NLP_RESOURCES)
9
+ from indicnlp import loader
10
+
11
+ loader.load()
12
+ from sacremoses import MosesPunctNormalizer
13
+ from sacremoses import MosesTokenizer
14
+ from sacremoses import MosesDetokenizer
15
+ from collections import defaultdict
16
+
17
+ from tqdm import tqdm
18
+ from joblib import Parallel, delayed
19
+
20
+ from indicnlp.tokenize import indic_tokenize
21
+ from indicnlp.tokenize import indic_detokenize
22
+ from indicnlp.normalize import indic_normalize
23
+ from indicnlp.transliterate import unicode_transliterate
24
+
25
+
26
+ en_tok = MosesTokenizer(lang="en")
27
+ en_normalizer = MosesPunctNormalizer()
28
+
29
+
30
+ def preprocess_line(line, normalizer, lang, transliterate=False):
31
+ if lang == "en":
32
+ return " ".join(
33
+ en_tok.tokenize(en_normalizer.normalize(line.strip()), escape=False)
34
+ )
35
+ elif transliterate:
36
+ # line = indic_detokenize.trivial_detokenize(line.strip(), lang)
37
+ return unicode_transliterate.UnicodeIndicTransliterator.transliterate(
38
+ " ".join(
39
+ indic_tokenize.trivial_tokenize(
40
+ normalizer.normalize(line.strip()), lang
41
+ )
42
+ ),
43
+ lang,
44
+ "hi",
45
+ ).replace(" ् ", "्")
46
+ else:
47
+ # we only need to transliterate for joint training
48
+ return " ".join(
49
+ indic_tokenize.trivial_tokenize(normalizer.normalize(line.strip()), lang)
50
+ )
51
+
52
+
53
+ def preprocess(infname, outfname, lang, transliterate=False):
54
+ """
55
+ Normalize, tokenize and script convert(for Indic)
56
+ return number of sentences input file
57
+
58
+ """
59
+
60
+ n = 0
61
+ num_lines = sum(1 for line in open(infname, "r"))
62
+ if lang == "en":
63
+ with open(infname, "r", encoding="utf-8") as infile, open(
64
+ outfname, "w", encoding="utf-8"
65
+ ) as outfile:
66
+
67
+ out_lines = Parallel(n_jobs=-1, backend="multiprocessing")(
68
+ delayed(preprocess_line)(line, None, lang)
69
+ for line in tqdm(infile, total=num_lines)
70
+ )
71
+
72
+ for line in out_lines:
73
+ outfile.write(line + "\n")
74
+ n += 1
75
+
76
+ else:
77
+ normfactory = indic_normalize.IndicNormalizerFactory()
78
+ normalizer = normfactory.get_normalizer(lang)
79
+ # reading
80
+ with open(infname, "r", encoding="utf-8") as infile, open(
81
+ outfname, "w", encoding="utf-8"
82
+ ) as outfile:
83
+
84
+ out_lines = Parallel(n_jobs=-1, backend="multiprocessing")(
85
+ delayed(preprocess_line)(line, normalizer, lang, transliterate)
86
+ for line in tqdm(infile, total=num_lines)
87
+ )
88
+
89
+ for line in out_lines:
90
+ outfile.write(line + "\n")
91
+ n += 1
92
+ return n
93
+
94
+
95
+ def old_preprocess(infname, outfname, lang):
96
+ """
97
+ Preparing each corpus file:
98
+ - Normalization
99
+ - Tokenization
100
+ - Script coversion to Devanagari for Indic scripts
101
+ """
102
+ n = 0
103
+ num_lines = sum(1 for line in open(infname, "r"))
104
+ # reading
105
+ with open(infname, "r", encoding="utf-8") as infile, open(
106
+ outfname, "w", encoding="utf-8"
107
+ ) as outfile:
108
+
109
+ if lang == "en":
110
+ en_tok = MosesTokenizer(lang="en")
111
+ en_normalizer = MosesPunctNormalizer()
112
+ for line in tqdm(infile, total=num_lines):
113
+ outline = " ".join(
114
+ en_tok.tokenize(en_normalizer.normalize(line.strip()), escape=False)
115
+ )
116
+ outfile.write(outline + "\n")
117
+ n += 1
118
+
119
+ else:
120
+ normfactory = indic_normalize.IndicNormalizerFactory()
121
+ normalizer = normfactory.get_normalizer(lang)
122
+ for line in tqdm(infile, total=num_lines):
123
+ outline = (
124
+ unicode_transliterate.UnicodeIndicTransliterator.transliterate(
125
+ " ".join(
126
+ indic_tokenize.trivial_tokenize(
127
+ normalizer.normalize(line.strip()), lang
128
+ )
129
+ ),
130
+ lang,
131
+ "hi",
132
+ ).replace(" ् ", "्")
133
+ )
134
+
135
+ outfile.write(outline + "\n")
136
+ n += 1
137
+ return n
138
+
139
+
140
+ if __name__ == "__main__":
141
+
142
+ # INDIC_NLP_LIB_HOME = "indic_nlp_library"
143
+ # INDIC_NLP_RESOURCES = "indic_nlp_resources"
144
+ # sys.path.append(r'{}'.format(INDIC_NLP_LIB_HOME))
145
+ # common.set_resources_path(INDIC_NLP_RESOURCES)
146
+
147
+ # data_dir = '../joint_training/v1'
148
+ # new_dir = data_dir + '.norm'
149
+ # for path, subdirs, files in os.walk(data_dir):
150
+ # for name in files:
151
+ # infile = os.path.join(path, name)
152
+ # lang = infile.split('.')[-1]
153
+ # outfile = os.path.join(path.replace(data_dir, new_dir), name)
154
+ # preprocess(infile, outfile, lang)
155
+ # loader.load()
156
+
157
+ infname = sys.argv[1]
158
+ outfname = sys.argv[2]
159
+ lang = sys.argv[3]
160
+
161
+ if len(sys.argv) == 4:
162
+ transliterate = False
163
+ elif len(sys.argv) == 5:
164
+ transliterate = sys.argv[4]
165
+ if transliterate.lower() == "true":
166
+ transliterate = True
167
+ else:
168
+ transliterate = False
169
+ else:
170
+ print(f"Invalid arguments: {sys.argv}")
171
+ exit()
172
+ print(preprocess(infname, outfname, lang, transliterate))
scripts/remove_large_sentences.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import sys
3
+
4
+
5
+ def remove_large_sentences(src_path, tgt_path):
6
+ count = 0
7
+ new_src_lines = []
8
+ new_tgt_lines = []
9
+ src_num_lines = sum(1 for line in open(src_path, "r", encoding="utf-8"))
10
+ tgt_num_lines = sum(1 for line in open(tgt_path, "r", encoding="utf-8"))
11
+ assert src_num_lines == tgt_num_lines
12
+ with open(src_path, encoding="utf-8") as f1, open(tgt_path, encoding="utf-8") as f2:
13
+ for src_line, tgt_line in tqdm(zip(f1, f2), total=src_num_lines):
14
+ src_tokens = src_line.strip().split(" ")
15
+ tgt_tokens = tgt_line.strip().split(" ")
16
+ if len(src_tokens) > 200 or len(tgt_tokens) > 200:
17
+ count += 1
18
+ continue
19
+ new_src_lines.append(src_line)
20
+ new_tgt_lines.append(tgt_line)
21
+ return count, new_src_lines, new_tgt_lines
22
+
23
+
24
+ def create_txt(outFile, lines, add_newline=False):
25
+ outfile = open("{0}".format(outFile), "w", encoding="utf-8")
26
+ for line in lines:
27
+ if add_newline:
28
+ outfile.write(line + "\n")
29
+ else:
30
+ outfile.write(line)
31
+ outfile.close()
32
+
33
+
34
+ if __name__ == "__main__":
35
+
36
+ src_path = sys.argv[1]
37
+ tgt_path = sys.argv[2]
38
+ new_src_path = sys.argv[3]
39
+ new_tgt_path = sys.argv[4]
40
+
41
+ count, new_src_lines, new_tgt_lines = remove_large_sentences(src_path, tgt_path)
42
+ print(f'{count} lines removed due to seq_len > 200')
43
+ create_txt(new_src_path, new_src_lines)
44
+ create_txt(new_tgt_path, new_tgt_lines)
scripts/remove_train_devtest_overlaps.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import string
3
+ import shutil
4
+ from itertools import permutations, chain
5
+ from collections import defaultdict
6
+ from tqdm import tqdm
7
+ import sys
8
+
9
+ INDIC_LANGS = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
10
+ # we will be testing the overlaps of training data with all these benchmarks
11
+ # benchmarks = ['wat2021-devtest', 'wat2020-devtest', 'wat-2018', 'wmt-news', 'ufal-ta', 'pmi']
12
+
13
+
14
+ def read_lines(path):
15
+ # if path doesnt exist, return empty list
16
+ if not os.path.exists(path):
17
+ return []
18
+ with open(path, "r") as f:
19
+ lines = f.readlines()
20
+ return lines
21
+
22
+
23
+ def create_txt(outFile, lines):
24
+ add_newline = not "\n" in lines[0]
25
+ outfile = open("{0}".format(outFile), "w")
26
+ for line in lines:
27
+ if add_newline:
28
+ outfile.write(line + "\n")
29
+ else:
30
+ outfile.write(line)
31
+
32
+ outfile.close()
33
+
34
+
35
+ def pair_dedup_files(src_file, tgt_file):
36
+ src_lines = read_lines(src_file)
37
+ tgt_lines = read_lines(tgt_file)
38
+ len_before = len(src_lines)
39
+
40
+ src_dedupped, tgt_dedupped = pair_dedup_lists(src_lines, tgt_lines)
41
+
42
+ len_after = len(src_dedupped)
43
+ num_duplicates = len_before - len_after
44
+
45
+ print(f"Dropped duplicate pairs in {src_file} Num duplicates -> {num_duplicates}")
46
+ create_txt(src_file, src_dedupped)
47
+ create_txt(tgt_file, tgt_dedupped)
48
+
49
+
50
+ def pair_dedup_lists(src_list, tgt_list):
51
+ src_tgt = list(set(zip(src_list, tgt_list)))
52
+ src_deduped, tgt_deduped = zip(*src_tgt)
53
+ return src_deduped, tgt_deduped
54
+
55
+
56
+ def strip_and_normalize(line):
57
+ # lowercase line, remove spaces and strip punctuation
58
+
59
+ # one of the fastest way to add an exclusion list and remove that
60
+ # list of characters from a string
61
+ # https://towardsdatascience.com/how-to-efficiently-remove-punctuations-from-a-string-899ad4a059fb
62
+ exclist = string.punctuation + "\u0964"
63
+ table_ = str.maketrans("", "", exclist)
64
+
65
+ line = line.replace(" ", "").lower()
66
+ # dont use this method, it is painfully slow
67
+ # line = "".join([i for i in line if i not in string.punctuation])
68
+ line = line.translate(table_)
69
+ return line
70
+
71
+
72
+ def expand_tupled_list(list_of_tuples):
73
+ # convert list of tuples into two lists
74
+ # https://stackoverflow.com/questions/8081545/how-to-convert-list-of-tuples-to-multiple-lists
75
+ # [(en, as), (as, bn), (bn, gu)] - > [en, as, bn], [as, bn, gu]
76
+ list_a, list_b = map(list, zip(*list_of_tuples))
77
+ return list_a, list_b
78
+
79
+
80
+ def get_src_tgt_lang_lists(many2many=False):
81
+ if many2many is False:
82
+ SRC_LANGS = ["en"]
83
+ TGT_LANGS = INDIC_LANGS
84
+ else:
85
+ all_languages = INDIC_LANGS + ["en"]
86
+ # lang_pairs = list(permutations(all_languages, 2))
87
+
88
+ SRC_LANGS, TGT_LANGS = all_languages, all_languages
89
+
90
+ return SRC_LANGS, TGT_LANGS
91
+
92
+
93
+ def normalize_and_gather_all_benchmarks(devtest_dir, many2many=False):
94
+
95
+ # This is a dict of dict of lists
96
+ # the first keys are for lang-pair, the second keys are for src/tgt
97
+ # the values are the devtest lines.
98
+ # so devtest_pairs_normalized[en-as][src] will store src(en lines)
99
+ # so devtest_pairs_normalized[en-as][tgt] will store tgt(as lines)
100
+ devtest_pairs_normalized = defaultdict(lambda: defaultdict(list))
101
+ SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)
102
+ benchmarks = os.listdir(devtest_dir)
103
+ for dataset in benchmarks:
104
+ for src_lang in SRC_LANGS:
105
+ for tgt_lang in TGT_LANGS:
106
+ if src_lang == tgt_lang:
107
+ continue
108
+ if dataset == "wat2021-devtest":
109
+ # wat2021 dev and test sets have differnet folder structure
110
+ src_dev = read_lines(f"{devtest_dir}/{dataset}/dev.{src_lang}")
111
+ tgt_dev = read_lines(f"{devtest_dir}/{dataset}/dev.{tgt_lang}")
112
+ src_test = read_lines(f"{devtest_dir}/{dataset}/test.{src_lang}")
113
+ tgt_test = read_lines(f"{devtest_dir}/{dataset}/test.{tgt_lang}")
114
+ else:
115
+ src_dev = read_lines(
116
+ f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/dev.{src_lang}"
117
+ )
118
+ tgt_dev = read_lines(
119
+ f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/dev.{tgt_lang}"
120
+ )
121
+ src_test = read_lines(
122
+ f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/test.{src_lang}"
123
+ )
124
+ tgt_test = read_lines(
125
+ f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/test.{tgt_lang}"
126
+ )
127
+
128
+ # if the tgt_pair data doesnt exist for a particular test set,
129
+ # it will be an empty list
130
+ if tgt_test == [] or tgt_dev == []:
131
+ # print(f'{dataset} does not have {src_lang}-{tgt_lang} data')
132
+ continue
133
+
134
+ # combine both dev and test sets into one
135
+ src_devtest = src_dev + src_test
136
+ tgt_devtest = tgt_dev + tgt_test
137
+
138
+ src_devtest = [strip_and_normalize(line) for line in src_devtest]
139
+ tgt_devtest = [strip_and_normalize(line) for line in tgt_devtest]
140
+
141
+ devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["src"].extend(
142
+ src_devtest
143
+ )
144
+ devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["tgt"].extend(
145
+ tgt_devtest
146
+ )
147
+
148
+ # dedup merged benchmark datasets
149
+ for src_lang in SRC_LANGS:
150
+ for tgt_lang in TGT_LANGS:
151
+ if src_lang == tgt_lang:
152
+ continue
153
+ src_devtest, tgt_devtest = (
154
+ devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["src"],
155
+ devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["tgt"],
156
+ )
157
+ # if the devtest data doesnt exist for the src-tgt pair then continue
158
+ if src_devtest == [] or tgt_devtest == []:
159
+ continue
160
+ src_devtest, tgt_devtest = pair_dedup_lists(src_devtest, tgt_devtest)
161
+ (
162
+ devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["src"],
163
+ devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["tgt"],
164
+ ) = (
165
+ src_devtest,
166
+ tgt_devtest,
167
+ )
168
+
169
+ return devtest_pairs_normalized
170
+
171
+
172
+ def remove_train_devtest_overlaps(train_dir, devtest_dir, many2many=False):
173
+
174
+ devtest_pairs_normalized = normalize_and_gather_all_benchmarks(
175
+ devtest_dir, many2many
176
+ )
177
+
178
+ SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)
179
+
180
+ if not many2many:
181
+ all_src_sentences_normalized = []
182
+ for key in devtest_pairs_normalized:
183
+ all_src_sentences_normalized.extend(devtest_pairs_normalized[key]["src"])
184
+ # remove all duplicates. Now this contains all the normalized
185
+ # english sentences in all test benchmarks across all lang pair
186
+ all_src_sentences_normalized = list(set(all_src_sentences_normalized))
187
+ else:
188
+ all_src_sentences_normalized = None
189
+
190
+ src_overlaps = []
191
+ tgt_overlaps = []
192
+ for src_lang in SRC_LANGS:
193
+ for tgt_lang in TGT_LANGS:
194
+ if src_lang == tgt_lang:
195
+ continue
196
+ new_src_train = []
197
+ new_tgt_train = []
198
+
199
+ pair = f"{src_lang}-{tgt_lang}"
200
+ src_train = read_lines(f"{train_dir}/{pair}/train.{src_lang}")
201
+ tgt_train = read_lines(f"{train_dir}/{pair}/train.{tgt_lang}")
202
+
203
+ len_before = len(src_train)
204
+ if len_before == 0:
205
+ continue
206
+
207
+ src_train_normalized = [strip_and_normalize(line) for line in src_train]
208
+ tgt_train_normalized = [strip_and_normalize(line) for line in tgt_train]
209
+
210
+ if all_src_sentences_normalized:
211
+ src_devtest_normalized = all_src_sentences_normalized
212
+ else:
213
+ src_devtest_normalized = devtest_pairs_normalized[pair]["src"]
214
+
215
+ tgt_devtest_normalized = devtest_pairs_normalized[pair]["tgt"]
216
+
217
+ # compute all src and tgt super strict overlaps for a lang pair
218
+ overlaps = set(src_train_normalized) & set(src_devtest_normalized)
219
+ src_overlaps.extend(list(overlaps))
220
+
221
+ overlaps = set(tgt_train_normalized) & set(tgt_devtest_normalized)
222
+ tgt_overlaps.extend(list(overlaps))
223
+ # dictionaries offer o(1) lookup
224
+ src_overlaps_dict = {}
225
+ tgt_overlaps_dict = {}
226
+ for line in src_overlaps:
227
+ src_overlaps_dict[line] = 1
228
+ for line in tgt_overlaps:
229
+ tgt_overlaps_dict[line] = 1
230
+
231
+ # loop to remove the ovelapped data
232
+ idx = -1
233
+ for src_line_norm, tgt_line_norm in tqdm(
234
+ zip(src_train_normalized, tgt_train_normalized), total=len_before
235
+ ):
236
+ idx += 1
237
+ if src_overlaps_dict.get(src_line_norm, None):
238
+ continue
239
+ if tgt_overlaps_dict.get(tgt_line_norm, None):
240
+ continue
241
+ new_src_train.append(src_train[idx])
242
+ new_tgt_train.append(tgt_train[idx])
243
+
244
+ len_after = len(new_src_train)
245
+ print(
246
+ f"Detected overlaps between train and devetest for {pair} is {len_before - len_after}"
247
+ )
248
+ print(f"saving new files at {train_dir}/{pair}/")
249
+ create_txt(f"{train_dir}/{pair}/train.{src_lang}", new_src_train)
250
+ create_txt(f"{train_dir}/{pair}/train.{tgt_lang}", new_tgt_train)
251
+
252
+
253
+ if __name__ == "__main__":
254
+ train_data_dir = sys.argv[1]
255
+ # benchmarks directory should contains all the test sets
256
+ devtest_data_dir = sys.argv[2]
257
+ if len(sys.argv) == 3:
258
+ many2many = False
259
+ elif len(sys.argv) == 4:
260
+ many2many = sys.argv[4]
261
+ if many2many.lower() == "true":
262
+ many2many = True
263
+ else:
264
+ many2many = False
265
+ remove_train_devtest_overlaps(train_data_dir, devtest_data_dir, many2many)