Spaces:
Runtime error
Runtime error
Hussain Shaikh
commited on
Commit
·
7edceed
1
Parent(s):
6325f49
final commit added required files
Browse files- .gitignore +143 -0
- api/api.py +152 -0
- api/punctuate.py +220 -0
- app.py +27 -4
- inference/__init__.py +0 -0
- inference/custom_interactive.py +298 -0
- inference/engine.py +198 -0
- legacy/apply_bpe_test_valid_notag.sh +33 -0
- legacy/apply_bpe_train_notag.sh +33 -0
- legacy/env.sh +17 -0
- legacy/indictrans_workflow.ipynb +643 -0
- legacy/install_fairseq.sh +45 -0
- legacy/run_inference.sh +80 -0
- legacy/run_joint_inference.sh +74 -0
- legacy/tpu_training_instructions.md +92 -0
- legacy/translate.sh +70 -0
- model_configs/__init__.py +1 -0
- model_configs/custom_transformer.py +38 -0
- requirements.txt +11 -0
- scripts/__init__.py +0 -0
- scripts/add_joint_tags_translate.py +61 -0
- scripts/add_tags_translate.py +33 -0
- scripts/clean_vocab.py +19 -0
- scripts/concat_joint_data.py +130 -0
- scripts/extract_non_english_pairs.py +108 -0
- scripts/postprocess_score.py +48 -0
- scripts/postprocess_translate.py +110 -0
- scripts/preprocess_translate.py +172 -0
- scripts/remove_large_sentences.py +44 -0
- scripts/remove_train_devtest_overlaps.py +265 -0
.gitignore
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#ignore libs folder we use
|
2 |
+
indic_nlp_library
|
3 |
+
indic_nlp_resources
|
4 |
+
subword-nmt
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
share/python-wheels/
|
29 |
+
*.egg-info/
|
30 |
+
.installed.cfg
|
31 |
+
*.egg
|
32 |
+
MANIFEST
|
33 |
+
|
34 |
+
# PyInstaller
|
35 |
+
# Usually these files are written by a python script from a template
|
36 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
37 |
+
*.manifest
|
38 |
+
*.spec
|
39 |
+
|
40 |
+
# Installer logs
|
41 |
+
pip-log.txt
|
42 |
+
pip-delete-this-directory.txt
|
43 |
+
|
44 |
+
# Unit test / coverage reports
|
45 |
+
htmlcov/
|
46 |
+
.tox/
|
47 |
+
.nox/
|
48 |
+
.coverage
|
49 |
+
.coverage.*
|
50 |
+
.cache
|
51 |
+
nosetests.xml
|
52 |
+
coverage.xml
|
53 |
+
*.cover
|
54 |
+
*.py,cover
|
55 |
+
.hypothesis/
|
56 |
+
.pytest_cache/
|
57 |
+
cover/
|
58 |
+
|
59 |
+
# Translations
|
60 |
+
*.mo
|
61 |
+
*.pot
|
62 |
+
|
63 |
+
# Django stuff:
|
64 |
+
*.log
|
65 |
+
local_settings.py
|
66 |
+
db.sqlite3
|
67 |
+
db.sqlite3-journal
|
68 |
+
|
69 |
+
# Flask stuff:
|
70 |
+
instance/
|
71 |
+
.webassets-cache
|
72 |
+
|
73 |
+
# Scrapy stuff:
|
74 |
+
.scrapy
|
75 |
+
|
76 |
+
# Sphinx documentation
|
77 |
+
docs/_build/
|
78 |
+
|
79 |
+
# PyBuilder
|
80 |
+
.pybuilder/
|
81 |
+
target/
|
82 |
+
|
83 |
+
# Jupyter Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# IPython
|
87 |
+
profile_default/
|
88 |
+
ipython_config.py
|
89 |
+
|
90 |
+
# pyenv
|
91 |
+
# For a library or package, you might want to ignore these files since the code is
|
92 |
+
# intended to run in multiple environments; otherwise, check them in:
|
93 |
+
# .python-version
|
94 |
+
|
95 |
+
# pipenv
|
96 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
97 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
98 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
99 |
+
# install all needed dependencies.
|
100 |
+
#Pipfile.lock
|
101 |
+
|
102 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
103 |
+
__pypackages__/
|
104 |
+
|
105 |
+
# Celery stuff
|
106 |
+
celerybeat-schedule
|
107 |
+
celerybeat.pid
|
108 |
+
|
109 |
+
# SageMath parsed files
|
110 |
+
*.sage.py
|
111 |
+
|
112 |
+
# Environments
|
113 |
+
.env
|
114 |
+
.venv
|
115 |
+
env/
|
116 |
+
venv/
|
117 |
+
ENV/
|
118 |
+
env.bak/
|
119 |
+
venv.bak/
|
120 |
+
|
121 |
+
# Spyder project settings
|
122 |
+
.spyderproject
|
123 |
+
.spyproject
|
124 |
+
|
125 |
+
# Rope project settings
|
126 |
+
.ropeproject
|
127 |
+
|
128 |
+
# mkdocs documentation
|
129 |
+
/site
|
130 |
+
|
131 |
+
# mypy
|
132 |
+
.mypy_cache/
|
133 |
+
.dmypy.json
|
134 |
+
dmypy.json
|
135 |
+
|
136 |
+
# Pyre type checker
|
137 |
+
.pyre/
|
138 |
+
|
139 |
+
# pytype static type analyzer
|
140 |
+
.pytype/
|
141 |
+
|
142 |
+
# Cython debug symbols
|
143 |
+
cython_debug/
|
api/api.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import re
|
4 |
+
from math import floor, ceil
|
5 |
+
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
|
6 |
+
# from nltk.tokenize import sent_tokenize
|
7 |
+
from flask import Flask, request, jsonify
|
8 |
+
from flask_cors import CORS, cross_origin
|
9 |
+
import webvtt
|
10 |
+
from io import StringIO
|
11 |
+
from mosestokenizer import MosesSentenceSplitter
|
12 |
+
|
13 |
+
from indicTrans.inference.engine import Model
|
14 |
+
from punctuate import RestorePuncts
|
15 |
+
from indicnlp.tokenize.sentence_tokenize import sentence_split
|
16 |
+
|
17 |
+
app = Flask(__name__)
|
18 |
+
cors = CORS(app)
|
19 |
+
app.config['CORS_HEADERS'] = 'Content-Type'
|
20 |
+
|
21 |
+
indic2en_model = Model(expdir='models/v3/indic-en')
|
22 |
+
en2indic_model = Model(expdir='models/v3/en-indic')
|
23 |
+
m2m_model = Model(expdir='models/m2m')
|
24 |
+
|
25 |
+
rpunct = RestorePuncts()
|
26 |
+
|
27 |
+
indic_language_dict = {
|
28 |
+
'Assamese': 'as',
|
29 |
+
'Hindi' : 'hi',
|
30 |
+
'Marathi' : 'mr',
|
31 |
+
'Tamil' : 'ta',
|
32 |
+
'Bengali' : 'bn',
|
33 |
+
'Kannada' : 'kn',
|
34 |
+
'Oriya' : 'or',
|
35 |
+
'Telugu' : 'te',
|
36 |
+
'Gujarati' : 'gu',
|
37 |
+
'Malayalam' : 'ml',
|
38 |
+
'Punjabi' : 'pa',
|
39 |
+
}
|
40 |
+
|
41 |
+
splitter = MosesSentenceSplitter('en')
|
42 |
+
|
43 |
+
def get_inference_params():
|
44 |
+
source_language = request.form['source_language']
|
45 |
+
target_language = request.form['target_language']
|
46 |
+
|
47 |
+
if source_language in indic_language_dict and target_language == 'English':
|
48 |
+
model = indic2en_model
|
49 |
+
source_lang = indic_language_dict[source_language]
|
50 |
+
target_lang = 'en'
|
51 |
+
elif source_language == 'English' and target_language in indic_language_dict:
|
52 |
+
model = en2indic_model
|
53 |
+
source_lang = 'en'
|
54 |
+
target_lang = indic_language_dict[target_language]
|
55 |
+
elif source_language in indic_language_dict and target_language in indic_language_dict:
|
56 |
+
model = m2m_model
|
57 |
+
source_lang = indic_language_dict[source_language]
|
58 |
+
target_lang = indic_language_dict[target_language]
|
59 |
+
|
60 |
+
return model, source_lang, target_lang
|
61 |
+
|
62 |
+
@app.route('/', methods=['GET'])
|
63 |
+
def main():
|
64 |
+
return "IndicTrans API"
|
65 |
+
|
66 |
+
@app.route('/supported_languages', methods=['GET'])
|
67 |
+
@cross_origin()
|
68 |
+
def supported_languages():
|
69 |
+
return jsonify(indic_language_dict)
|
70 |
+
|
71 |
+
@app.route("/translate", methods=['POST'])
|
72 |
+
@cross_origin()
|
73 |
+
def infer_indic_en():
|
74 |
+
model, source_lang, target_lang = get_inference_params()
|
75 |
+
source_text = request.form['text']
|
76 |
+
|
77 |
+
start_time = time.time()
|
78 |
+
target_text = model.translate_paragraph(source_text, source_lang, target_lang)
|
79 |
+
end_time = time.time()
|
80 |
+
return {'text':target_text, 'duration':round(end_time-start_time, 2)}
|
81 |
+
|
82 |
+
@app.route("/translate_vtt", methods=['POST'])
|
83 |
+
@cross_origin()
|
84 |
+
def infer_vtt_indic_en():
|
85 |
+
start_time = time.time()
|
86 |
+
model, source_lang, target_lang = get_inference_params()
|
87 |
+
source_text = request.form['text']
|
88 |
+
# vad_segments = request.form['vad_nochunk'] # Assuming it is an array of start & end timestamps
|
89 |
+
|
90 |
+
vad = webvtt.read_buffer(StringIO(source_text))
|
91 |
+
source_sentences = [v.text.replace('\r', '').replace('\n', ' ') for v in vad]
|
92 |
+
|
93 |
+
## SUMANTH LOGIC HERE ##
|
94 |
+
|
95 |
+
# for each vad timestamp, do:
|
96 |
+
large_sentence = ' '.join(source_sentences) # only sentences in that time range
|
97 |
+
large_sentence = large_sentence.lower()
|
98 |
+
# split_sents = sentence_split(large_sentence, 'en')
|
99 |
+
# print(split_sents)
|
100 |
+
|
101 |
+
large_sentence = re.sub(r'[^\w\s]', '', large_sentence)
|
102 |
+
punctuated = rpunct.punctuate(large_sentence, batch_size=32)
|
103 |
+
end_time = time.time()
|
104 |
+
print("Time Taken for punctuation: {} s".format(end_time - start_time))
|
105 |
+
start_time = time.time()
|
106 |
+
split_sents = splitter([punctuated]) ### Please uncomment
|
107 |
+
|
108 |
+
|
109 |
+
# print(split_sents)
|
110 |
+
# output_sentence_punctuated = model.translate_paragraph(punctuated, source_lang, target_lang)
|
111 |
+
output_sents = model.batch_translate(split_sents, source_lang, target_lang)
|
112 |
+
# print(output_sents)
|
113 |
+
# output_sents = split_sents
|
114 |
+
# print(output_sents)
|
115 |
+
# align this to those range of source_sentences in `captions`
|
116 |
+
|
117 |
+
map_ = {split_sents[i] : output_sents[i] for i in range(len(split_sents))}
|
118 |
+
# print(map_)
|
119 |
+
punct_para = ' '.join(list(map_.keys()))
|
120 |
+
nmt_para = ' '.join(list(map_.values()))
|
121 |
+
nmt_words = nmt_para.split(' ')
|
122 |
+
|
123 |
+
len_punct = len(punct_para.split(' '))
|
124 |
+
len_nmt = len(nmt_para.split(' '))
|
125 |
+
|
126 |
+
start = 0
|
127 |
+
for i in range(len(vad)):
|
128 |
+
if vad[i].text == '':
|
129 |
+
continue
|
130 |
+
|
131 |
+
len_caption = len(vad[i].text.split(' '))
|
132 |
+
frac = (len_caption / len_punct)
|
133 |
+
# frac = round(frac, 2)
|
134 |
+
|
135 |
+
req_nmt_size = floor(frac * len_nmt)
|
136 |
+
# print(frac, req_nmt_size)
|
137 |
+
|
138 |
+
vad[i].text = ' '.join(nmt_words[start:start+req_nmt_size])
|
139 |
+
# print(vad[i].text)
|
140 |
+
# print(start, req_nmt_size)
|
141 |
+
start += req_nmt_size
|
142 |
+
|
143 |
+
end_time = time.time()
|
144 |
+
|
145 |
+
print("Time Taken for translation: {} s".format(end_time - start_time))
|
146 |
+
|
147 |
+
# vad.save('aligned.vtt')
|
148 |
+
|
149 |
+
return {
|
150 |
+
'text': vad.content,
|
151 |
+
# 'duration':round(end_time-start_time, 2)
|
152 |
+
}
|
api/punctuate.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# 💾⚙️🔮
|
3 |
+
|
4 |
+
# taken from https://github.com/Felflare/rpunct/blob/master/rpunct/punctuate.py
|
5 |
+
# modified to support batching during gpu inference
|
6 |
+
|
7 |
+
|
8 |
+
__author__ = "Daulet N."
|
9 |
+
__email__ = "daulet.nurmanbetov@gmail.com"
|
10 |
+
|
11 |
+
import time
|
12 |
+
import logging
|
13 |
+
import webvtt
|
14 |
+
import torch
|
15 |
+
from io import StringIO
|
16 |
+
from nltk.tokenize import sent_tokenize
|
17 |
+
#from langdetect import detect
|
18 |
+
from simpletransformers.ner import NERModel
|
19 |
+
|
20 |
+
|
21 |
+
class RestorePuncts:
|
22 |
+
def __init__(self, wrds_per_pred=250):
|
23 |
+
self.wrds_per_pred = wrds_per_pred
|
24 |
+
self.overlap_wrds = 30
|
25 |
+
self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
|
26 |
+
self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels,
|
27 |
+
args={"silent": True, "max_seq_length": 512})
|
28 |
+
# use_cuda isnt working and this hack seems to load the model correctly to the gpu
|
29 |
+
self.model.device = torch.device("cuda:1")
|
30 |
+
# dummy punctuate to load the model onto gpu
|
31 |
+
self.punctuate("hello how are you")
|
32 |
+
|
33 |
+
def punctuate(self, text: str, batch_size:int=32, lang:str=''):
|
34 |
+
"""
|
35 |
+
Performs punctuation restoration on arbitrarily large text.
|
36 |
+
Detects if input is not English, if non-English was detected terminates predictions.
|
37 |
+
Overrride by supplying `lang='en'`
|
38 |
+
|
39 |
+
Args:
|
40 |
+
- text (str): Text to punctuate, can be few words to as large as you want.
|
41 |
+
- lang (str): Explicit language of input text.
|
42 |
+
"""
|
43 |
+
#if not lang and len(text) > 10:
|
44 |
+
# lang = detect(text)
|
45 |
+
#if lang != 'en':
|
46 |
+
# raise Exception(F"""Non English text detected. Restore Punctuation works only for English.
|
47 |
+
# If you are certain the input is English, pass argument lang='en' to this function.
|
48 |
+
# Punctuate received: {text}""")
|
49 |
+
|
50 |
+
def chunks(L, n):
|
51 |
+
return [L[x : x + n] for x in range(0, len(L), n)]
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
# plit up large text into bert digestable chunks
|
56 |
+
splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
|
57 |
+
|
58 |
+
texts = [i["text"] for i in splits]
|
59 |
+
batches = chunks(texts, batch_size)
|
60 |
+
preds_lst = []
|
61 |
+
|
62 |
+
|
63 |
+
for batch in batches:
|
64 |
+
batch_preds, _ = self.model.predict(batch)
|
65 |
+
preds_lst.extend(batch_preds)
|
66 |
+
|
67 |
+
|
68 |
+
# predict slices
|
69 |
+
# full_preds_lst contains tuple of labels and logits
|
70 |
+
#full_preds_lst = [self.predict(i['text']) for i in splits]
|
71 |
+
# extract predictions, and discard logits
|
72 |
+
#preds_lst = [i[0][0] for i in full_preds_lst]
|
73 |
+
# join text slices
|
74 |
+
combined_preds = self.combine_results(text, preds_lst)
|
75 |
+
# create punctuated prediction
|
76 |
+
punct_text = self.punctuate_texts(combined_preds)
|
77 |
+
return punct_text
|
78 |
+
|
79 |
+
def predict(self, input_slice):
|
80 |
+
"""
|
81 |
+
Passes the unpunctuated text to the model for punctuation.
|
82 |
+
"""
|
83 |
+
predictions, raw_outputs = self.model.predict([input_slice])
|
84 |
+
return predictions, raw_outputs
|
85 |
+
|
86 |
+
@staticmethod
|
87 |
+
def split_on_toks(text, length, overlap):
|
88 |
+
"""
|
89 |
+
Splits text into predefined slices of overlapping text with indexes (offsets)
|
90 |
+
that tie-back to original text.
|
91 |
+
This is done to bypass 512 token limit on transformer models by sequentially
|
92 |
+
feeding chunks of < 512 toks.
|
93 |
+
Example output:
|
94 |
+
[{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]
|
95 |
+
"""
|
96 |
+
wrds = text.replace('\n', ' ').split(" ")
|
97 |
+
resp = []
|
98 |
+
lst_chunk_idx = 0
|
99 |
+
i = 0
|
100 |
+
|
101 |
+
while True:
|
102 |
+
# words in the chunk and the overlapping portion
|
103 |
+
wrds_len = wrds[(length * i):(length * (i + 1))]
|
104 |
+
wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
|
105 |
+
wrds_split = wrds_len + wrds_ovlp
|
106 |
+
|
107 |
+
# Break loop if no more words
|
108 |
+
if not wrds_split:
|
109 |
+
break
|
110 |
+
|
111 |
+
wrds_str = " ".join(wrds_split)
|
112 |
+
nxt_chunk_start_idx = len(" ".join(wrds_len))
|
113 |
+
lst_char_idx = len(" ".join(wrds_split))
|
114 |
+
|
115 |
+
resp_obj = {
|
116 |
+
"text": wrds_str,
|
117 |
+
"start_idx": lst_chunk_idx,
|
118 |
+
"end_idx": lst_char_idx + lst_chunk_idx,
|
119 |
+
}
|
120 |
+
|
121 |
+
resp.append(resp_obj)
|
122 |
+
lst_chunk_idx += nxt_chunk_start_idx + 1
|
123 |
+
i += 1
|
124 |
+
logging.info(f"Sliced transcript into {len(resp)} slices.")
|
125 |
+
return resp
|
126 |
+
|
127 |
+
@staticmethod
|
128 |
+
def combine_results(full_text: str, text_slices):
|
129 |
+
"""
|
130 |
+
Given a full text and predictions of each slice combines predictions into a single text again.
|
131 |
+
Performs validataion wether text was combined correctly
|
132 |
+
"""
|
133 |
+
split_full_text = full_text.replace('\n', ' ').split(" ")
|
134 |
+
split_full_text = [i for i in split_full_text if i]
|
135 |
+
split_full_text_len = len(split_full_text)
|
136 |
+
output_text = []
|
137 |
+
index = 0
|
138 |
+
|
139 |
+
if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
|
140 |
+
text_slices = text_slices[:-1]
|
141 |
+
|
142 |
+
for _slice in text_slices:
|
143 |
+
slice_wrds = len(_slice)
|
144 |
+
for ix, wrd in enumerate(_slice):
|
145 |
+
# print(index, "|", str(list(wrd.keys())[0]), "|", split_full_text[index])
|
146 |
+
if index == split_full_text_len:
|
147 |
+
break
|
148 |
+
|
149 |
+
if split_full_text[index] == str(list(wrd.keys())[0]) and \
|
150 |
+
ix <= slice_wrds - 3 and text_slices[-1] != _slice:
|
151 |
+
index += 1
|
152 |
+
pred_item_tuple = list(wrd.items())[0]
|
153 |
+
output_text.append(pred_item_tuple)
|
154 |
+
elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
|
155 |
+
index += 1
|
156 |
+
pred_item_tuple = list(wrd.items())[0]
|
157 |
+
output_text.append(pred_item_tuple)
|
158 |
+
assert [i[0] for i in output_text] == split_full_text
|
159 |
+
return output_text
|
160 |
+
|
161 |
+
@staticmethod
|
162 |
+
def punctuate_texts(full_pred: list):
|
163 |
+
"""
|
164 |
+
Given a list of Predictions from the model, applies the predictions to text,
|
165 |
+
thus punctuating it.
|
166 |
+
"""
|
167 |
+
punct_resp = ""
|
168 |
+
for i in full_pred:
|
169 |
+
word, label = i
|
170 |
+
if label[-1] == "U":
|
171 |
+
punct_wrd = word.capitalize()
|
172 |
+
else:
|
173 |
+
punct_wrd = word
|
174 |
+
|
175 |
+
if label[0] != "O":
|
176 |
+
punct_wrd += label[0]
|
177 |
+
|
178 |
+
punct_resp += punct_wrd + " "
|
179 |
+
punct_resp = punct_resp.strip()
|
180 |
+
# Append trailing period if doesnt exist.
|
181 |
+
if punct_resp[-1].isalnum():
|
182 |
+
punct_resp += "."
|
183 |
+
return punct_resp
|
184 |
+
|
185 |
+
|
186 |
+
if __name__ == "__main__":
|
187 |
+
|
188 |
+
start = time.time()
|
189 |
+
punct_model = RestorePuncts()
|
190 |
+
|
191 |
+
load_model = time.time()
|
192 |
+
print(f'Time to load model: {load_model - start}')
|
193 |
+
# read test file
|
194 |
+
# with open('en_lower.txt', 'r') as fp:
|
195 |
+
# # test_sample = fp.read()
|
196 |
+
# lines = fp.readlines()
|
197 |
+
|
198 |
+
with open('sample.vtt', 'r') as fp:
|
199 |
+
source_text = fp.read()
|
200 |
+
|
201 |
+
# captions = webvtt.read_buffer(StringIO(source_text))
|
202 |
+
captions = webvtt.read('sample.vtt')
|
203 |
+
source_sentences = [caption.text.replace('\r', '').replace('\n', ' ') for caption in captions]
|
204 |
+
|
205 |
+
# print(source_sentences)
|
206 |
+
|
207 |
+
sent = ' '.join(source_sentences)
|
208 |
+
punctuated = punct_model.punctuate(sent)
|
209 |
+
|
210 |
+
tokenised = sent_tokenize(punctuated)
|
211 |
+
# print(tokenised)
|
212 |
+
|
213 |
+
for i in range(len(tokenised)):
|
214 |
+
captions[i].text = tokenised[i]
|
215 |
+
# return captions.content
|
216 |
+
captions.save('my_captions.vtt')
|
217 |
+
|
218 |
+
end = time.time()
|
219 |
+
print(f'Time for run: {end - load_model}')
|
220 |
+
print(f'Total time: {end - start}')
|
app.py
CHANGED
@@ -1,7 +1,30 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
4 |
-
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
import gradio as gr
|
3 |
|
4 |
+
download="wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1IpcnaQ2ScX_zodt2aLlXa_5Kkntl0nue' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\\n/p')&id=1IpcnaQ2ScX_zodt2aLlXa_5Kkntl0nue\" -O en-indic.zip && rm -rf /tmp/cookies.txt"
|
5 |
+
os.system(download)
|
6 |
+
os.system('unzip /home/user/app/en-indic.zip')
|
7 |
|
8 |
+
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
|
9 |
+
import gradio as gr
|
10 |
+
from inference.engine import Model
|
11 |
+
indic2en_model = Model(expdir='/home/user/app/en-indic')
|
12 |
+
|
13 |
+
INDIC = {"Assamese": "as", "Bengali": "bn", "Gujarati": "gu", "Hindi": "hi","Kannada": "kn","Malayalam": "ml", "Marathi": "mr", "Odia": "or","Punjabi": "pa","Tamil": "ta", "Telugu" : "te"}
|
14 |
+
|
15 |
+
|
16 |
+
def translate(text, lang):
|
17 |
+
return indic2en_model.translate_paragraph(text, 'en', INDIC[lang])
|
18 |
+
|
19 |
+
|
20 |
+
languages = list(INDIC.keys())
|
21 |
+
drop_down = gr.inputs.Dropdown(languages, type="value", default="Hindi", label="Select Target Language")
|
22 |
+
text = gr.inputs.Textbox(lines=5, placeholder="Enter Text to translate", default="", label="Enter Text in English")
|
23 |
+
text_ouptut = gr.outputs.Textbox(type="auto", label="Translated text in Target Language")
|
24 |
+
|
25 |
+
# example=[['I want to translate this sentence in Hindi','Hindi'],
|
26 |
+
# ['I am feeling very good today.', 'Bengali']]
|
27 |
+
|
28 |
+
supported_lang = ', '.join(languages)
|
29 |
+
iface = gr.Interface(fn=translate, inputs=[text,drop_down] , outputs=text_ouptut, title='IndicTrans NMT System', description = 'Currently the model supports ' + supported_lang, article = 'Original repository can be found [here](https://github.com/AI4Bharat/indicTrans)' , examples=None)
|
30 |
+
iface.launch(enable_queue=True)
|
inference/__init__.py
ADDED
File without changes
|
inference/custom_interactive.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# python wrapper for fairseq-interactive command line tool
|
2 |
+
|
3 |
+
#!/usr/bin/env python3 -u
|
4 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
5 |
+
#
|
6 |
+
# This source code is licensed under the MIT license found in the
|
7 |
+
# LICENSE file in the root directory of this source tree.
|
8 |
+
"""
|
9 |
+
Translate raw text with a trained model. Batches data on-the-fly.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import ast
|
13 |
+
from collections import namedtuple
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from fairseq import checkpoint_utils, options, tasks, utils
|
17 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
18 |
+
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
|
19 |
+
from fairseq_cli.generate import get_symbols_to_strip_from_output
|
20 |
+
|
21 |
+
import codecs
|
22 |
+
|
23 |
+
|
24 |
+
Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
|
25 |
+
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")
|
26 |
+
|
27 |
+
|
28 |
+
def make_batches(
|
29 |
+
lines, cfg, task, max_positions, encode_fn, constrainted_decoding=False
|
30 |
+
):
|
31 |
+
def encode_fn_target(x):
|
32 |
+
return encode_fn(x)
|
33 |
+
|
34 |
+
if constrainted_decoding:
|
35 |
+
# Strip (tab-delimited) contraints, if present, from input lines,
|
36 |
+
# store them in batch_constraints
|
37 |
+
batch_constraints = [list() for _ in lines]
|
38 |
+
for i, line in enumerate(lines):
|
39 |
+
if "\t" in line:
|
40 |
+
lines[i], *batch_constraints[i] = line.split("\t")
|
41 |
+
|
42 |
+
# Convert each List[str] to List[Tensor]
|
43 |
+
for i, constraint_list in enumerate(batch_constraints):
|
44 |
+
batch_constraints[i] = [
|
45 |
+
task.target_dictionary.encode_line(
|
46 |
+
encode_fn_target(constraint),
|
47 |
+
append_eos=False,
|
48 |
+
add_if_not_exist=False,
|
49 |
+
)
|
50 |
+
for constraint in constraint_list
|
51 |
+
]
|
52 |
+
|
53 |
+
if constrainted_decoding:
|
54 |
+
constraints_tensor = pack_constraints(batch_constraints)
|
55 |
+
else:
|
56 |
+
constraints_tensor = None
|
57 |
+
|
58 |
+
tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)
|
59 |
+
|
60 |
+
itr = task.get_batch_iterator(
|
61 |
+
dataset=task.build_dataset_for_inference(
|
62 |
+
tokens, lengths, constraints=constraints_tensor
|
63 |
+
),
|
64 |
+
max_tokens=cfg.dataset.max_tokens,
|
65 |
+
max_sentences=cfg.dataset.batch_size,
|
66 |
+
max_positions=max_positions,
|
67 |
+
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
|
68 |
+
).next_epoch_itr(shuffle=False)
|
69 |
+
for batch in itr:
|
70 |
+
ids = batch["id"]
|
71 |
+
src_tokens = batch["net_input"]["src_tokens"]
|
72 |
+
src_lengths = batch["net_input"]["src_lengths"]
|
73 |
+
constraints = batch.get("constraints", None)
|
74 |
+
|
75 |
+
yield Batch(
|
76 |
+
ids=ids,
|
77 |
+
src_tokens=src_tokens,
|
78 |
+
src_lengths=src_lengths,
|
79 |
+
constraints=constraints,
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
class Translator:
|
84 |
+
def __init__(
|
85 |
+
self, data_dir, checkpoint_path, batch_size=25, constrained_decoding=False
|
86 |
+
):
|
87 |
+
|
88 |
+
self.constrained_decoding = constrained_decoding
|
89 |
+
self.parser = options.get_generation_parser(interactive=True)
|
90 |
+
# buffer_size is currently not used but we just initialize it to batch
|
91 |
+
# size + 1 to avoid any assertion errors.
|
92 |
+
if self.constrained_decoding:
|
93 |
+
self.parser.set_defaults(
|
94 |
+
path=checkpoint_path,
|
95 |
+
remove_bpe="subword_nmt",
|
96 |
+
num_workers=-1,
|
97 |
+
constraints="ordered",
|
98 |
+
batch_size=batch_size,
|
99 |
+
buffer_size=batch_size + 1,
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
self.parser.set_defaults(
|
103 |
+
path=checkpoint_path,
|
104 |
+
remove_bpe="subword_nmt",
|
105 |
+
num_workers=-1,
|
106 |
+
batch_size=batch_size,
|
107 |
+
buffer_size=batch_size + 1,
|
108 |
+
)
|
109 |
+
args = options.parse_args_and_arch(self.parser, input_args=[data_dir])
|
110 |
+
# we are explictly setting src_lang and tgt_lang here
|
111 |
+
# generally the data_dir we pass contains {split}-{src_lang}-{tgt_lang}.*.idx files from
|
112 |
+
# which fairseq infers the src and tgt langs(if these are not passed). In deployment we dont
|
113 |
+
# use any idx files and only store the SRC and TGT dictionaries.
|
114 |
+
args.source_lang = "SRC"
|
115 |
+
args.target_lang = "TGT"
|
116 |
+
# since we are truncating sentences to max_seq_len in engine, we can set it to False here
|
117 |
+
args.skip_invalid_size_inputs_valid_test = False
|
118 |
+
|
119 |
+
# we have custom architechtures in this folder and we will let fairseq
|
120 |
+
# import this
|
121 |
+
args.user_dir = "model_configs"
|
122 |
+
self.cfg = convert_namespace_to_omegaconf(args)
|
123 |
+
|
124 |
+
utils.import_user_module(self.cfg.common)
|
125 |
+
|
126 |
+
if self.cfg.interactive.buffer_size < 1:
|
127 |
+
self.cfg.interactive.buffer_size = 1
|
128 |
+
if self.cfg.dataset.max_tokens is None and self.cfg.dataset.batch_size is None:
|
129 |
+
self.cfg.dataset.batch_size = 1
|
130 |
+
|
131 |
+
assert (
|
132 |
+
not self.cfg.generation.sampling
|
133 |
+
or self.cfg.generation.nbest == self.cfg.generation.beam
|
134 |
+
), "--sampling requires --nbest to be equal to --beam"
|
135 |
+
assert (
|
136 |
+
not self.cfg.dataset.batch_size
|
137 |
+
or self.cfg.dataset.batch_size <= self.cfg.interactive.buffer_size
|
138 |
+
), "--batch-size cannot be larger than --buffer-size"
|
139 |
+
|
140 |
+
# Fix seed for stochastic decoding
|
141 |
+
# if self.cfg.common.seed is not None and not self.cfg.generation.no_seed_provided:
|
142 |
+
# np.random.seed(self.cfg.common.seed)
|
143 |
+
# utils.set_torch_seed(self.cfg.common.seed)
|
144 |
+
|
145 |
+
# if not self.constrained_decoding:
|
146 |
+
# self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu
|
147 |
+
# else:
|
148 |
+
# self.use_cuda = False
|
149 |
+
|
150 |
+
self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu
|
151 |
+
|
152 |
+
# Setup task, e.g., translation
|
153 |
+
self.task = tasks.setup_task(self.cfg.task)
|
154 |
+
|
155 |
+
# Load ensemble
|
156 |
+
overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
|
157 |
+
self.models, self._model_args = checkpoint_utils.load_model_ensemble(
|
158 |
+
utils.split_paths(self.cfg.common_eval.path),
|
159 |
+
arg_overrides=overrides,
|
160 |
+
task=self.task,
|
161 |
+
suffix=self.cfg.checkpoint.checkpoint_suffix,
|
162 |
+
strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
|
163 |
+
num_shards=self.cfg.checkpoint.checkpoint_shard_count,
|
164 |
+
)
|
165 |
+
|
166 |
+
# Set dictionaries
|
167 |
+
self.src_dict = self.task.source_dictionary
|
168 |
+
self.tgt_dict = self.task.target_dictionary
|
169 |
+
|
170 |
+
# Optimize ensemble for generation
|
171 |
+
for model in self.models:
|
172 |
+
if model is None:
|
173 |
+
continue
|
174 |
+
if self.cfg.common.fp16:
|
175 |
+
model.half()
|
176 |
+
if (
|
177 |
+
self.use_cuda
|
178 |
+
and not self.cfg.distributed_training.pipeline_model_parallel
|
179 |
+
):
|
180 |
+
model.cuda()
|
181 |
+
model.prepare_for_inference_(self.cfg)
|
182 |
+
|
183 |
+
# Initialize generator
|
184 |
+
self.generator = self.task.build_generator(self.models, self.cfg.generation)
|
185 |
+
|
186 |
+
# Handle tokenization and BPE
|
187 |
+
self.tokenizer = self.task.build_tokenizer(self.cfg.tokenizer)
|
188 |
+
self.bpe = self.task.build_bpe(self.cfg.bpe)
|
189 |
+
|
190 |
+
# Load alignment dictionary for unknown word replacement
|
191 |
+
# (None if no unknown word replacement, empty if no path to align dictionary)
|
192 |
+
self.align_dict = utils.load_align_dict(self.cfg.generation.replace_unk)
|
193 |
+
|
194 |
+
self.max_positions = utils.resolve_max_positions(
|
195 |
+
self.task.max_positions(), *[model.max_positions() for model in self.models]
|
196 |
+
)
|
197 |
+
|
198 |
+
def encode_fn(self, x):
|
199 |
+
if self.tokenizer is not None:
|
200 |
+
x = self.tokenizer.encode(x)
|
201 |
+
if self.bpe is not None:
|
202 |
+
x = self.bpe.encode(x)
|
203 |
+
return x
|
204 |
+
|
205 |
+
def decode_fn(self, x):
|
206 |
+
if self.bpe is not None:
|
207 |
+
x = self.bpe.decode(x)
|
208 |
+
if self.tokenizer is not None:
|
209 |
+
x = self.tokenizer.decode(x)
|
210 |
+
return x
|
211 |
+
|
212 |
+
def translate(self, inputs, constraints=None):
|
213 |
+
if self.constrained_decoding and constraints is None:
|
214 |
+
raise ValueError("Constraints cant be None in constrained decoding mode")
|
215 |
+
if not self.constrained_decoding and constraints is not None:
|
216 |
+
raise ValueError("Cannot pass constraints during normal translation")
|
217 |
+
if constraints:
|
218 |
+
constrained_decoding = True
|
219 |
+
modified_inputs = []
|
220 |
+
for _input, constraint in zip(inputs, constraints):
|
221 |
+
modified_inputs.append(_input + f"\t{constraint}")
|
222 |
+
inputs = modified_inputs
|
223 |
+
else:
|
224 |
+
constrained_decoding = False
|
225 |
+
|
226 |
+
start_id = 0
|
227 |
+
results = []
|
228 |
+
final_translations = []
|
229 |
+
for batch in make_batches(
|
230 |
+
inputs,
|
231 |
+
self.cfg,
|
232 |
+
self.task,
|
233 |
+
self.max_positions,
|
234 |
+
self.encode_fn,
|
235 |
+
constrained_decoding,
|
236 |
+
):
|
237 |
+
bsz = batch.src_tokens.size(0)
|
238 |
+
src_tokens = batch.src_tokens
|
239 |
+
src_lengths = batch.src_lengths
|
240 |
+
constraints = batch.constraints
|
241 |
+
if self.use_cuda:
|
242 |
+
src_tokens = src_tokens.cuda()
|
243 |
+
src_lengths = src_lengths.cuda()
|
244 |
+
if constraints is not None:
|
245 |
+
constraints = constraints.cuda()
|
246 |
+
|
247 |
+
sample = {
|
248 |
+
"net_input": {
|
249 |
+
"src_tokens": src_tokens,
|
250 |
+
"src_lengths": src_lengths,
|
251 |
+
},
|
252 |
+
}
|
253 |
+
|
254 |
+
translations = self.task.inference_step(
|
255 |
+
self.generator, self.models, sample, constraints=constraints
|
256 |
+
)
|
257 |
+
|
258 |
+
list_constraints = [[] for _ in range(bsz)]
|
259 |
+
if constrained_decoding:
|
260 |
+
list_constraints = [unpack_constraints(c) for c in constraints]
|
261 |
+
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
|
262 |
+
src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
|
263 |
+
constraints = list_constraints[i]
|
264 |
+
results.append(
|
265 |
+
(
|
266 |
+
start_id + id,
|
267 |
+
src_tokens_i,
|
268 |
+
hypos,
|
269 |
+
{
|
270 |
+
"constraints": constraints,
|
271 |
+
},
|
272 |
+
)
|
273 |
+
)
|
274 |
+
|
275 |
+
# sort output to match input order
|
276 |
+
for id_, src_tokens, hypos, _ in sorted(results, key=lambda x: x[0]):
|
277 |
+
src_str = ""
|
278 |
+
if self.src_dict is not None:
|
279 |
+
src_str = self.src_dict.string(
|
280 |
+
src_tokens, self.cfg.common_eval.post_process
|
281 |
+
)
|
282 |
+
|
283 |
+
# Process top predictions
|
284 |
+
for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
|
285 |
+
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
|
286 |
+
hypo_tokens=hypo["tokens"].int().cpu(),
|
287 |
+
src_str=src_str,
|
288 |
+
alignment=hypo["alignment"],
|
289 |
+
align_dict=self.align_dict,
|
290 |
+
tgt_dict=self.tgt_dict,
|
291 |
+
remove_bpe="subword_nmt",
|
292 |
+
extra_symbols_to_ignore=get_symbols_to_strip_from_output(
|
293 |
+
self.generator
|
294 |
+
),
|
295 |
+
)
|
296 |
+
detok_hypo_str = self.decode_fn(hypo_str)
|
297 |
+
final_translations.append(detok_hypo_str)
|
298 |
+
return final_translations
|
inference/engine.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import truncate
|
2 |
+
from sacremoses import MosesPunctNormalizer
|
3 |
+
from sacremoses import MosesTokenizer
|
4 |
+
from sacremoses import MosesDetokenizer
|
5 |
+
from subword_nmt.apply_bpe import BPE, read_vocabulary
|
6 |
+
import codecs
|
7 |
+
from tqdm import tqdm
|
8 |
+
from indicnlp.tokenize import indic_tokenize
|
9 |
+
from indicnlp.tokenize import indic_detokenize
|
10 |
+
from indicnlp.normalize import indic_normalize
|
11 |
+
from indicnlp.transliterate import unicode_transliterate
|
12 |
+
from mosestokenizer import MosesSentenceSplitter
|
13 |
+
from indicnlp.tokenize import sentence_tokenize
|
14 |
+
|
15 |
+
from inference.custom_interactive import Translator
|
16 |
+
|
17 |
+
|
18 |
+
INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
|
19 |
+
|
20 |
+
|
21 |
+
def split_sentences(paragraph, language):
|
22 |
+
if language == "en":
|
23 |
+
with MosesSentenceSplitter(language) as splitter:
|
24 |
+
return splitter([paragraph])
|
25 |
+
elif language in INDIC:
|
26 |
+
return sentence_tokenize.sentence_split(paragraph, lang=language)
|
27 |
+
|
28 |
+
|
29 |
+
def add_token(sent, tag_infos):
|
30 |
+
"""add special tokens specified by tag_infos to each element in list
|
31 |
+
|
32 |
+
tag_infos: list of tuples (tag_type,tag)
|
33 |
+
|
34 |
+
each tag_info results in a token of the form: __{tag_type}__{tag}__
|
35 |
+
|
36 |
+
"""
|
37 |
+
|
38 |
+
tokens = []
|
39 |
+
for tag_type, tag in tag_infos:
|
40 |
+
token = "__" + tag_type + "__" + tag + "__"
|
41 |
+
tokens.append(token)
|
42 |
+
|
43 |
+
return " ".join(tokens) + " " + sent
|
44 |
+
|
45 |
+
|
46 |
+
def apply_lang_tags(sents, src_lang, tgt_lang):
|
47 |
+
tagged_sents = []
|
48 |
+
for sent in sents:
|
49 |
+
tagged_sent = add_token(sent.strip(), [("src", src_lang), ("tgt", tgt_lang)])
|
50 |
+
tagged_sents.append(tagged_sent)
|
51 |
+
return tagged_sents
|
52 |
+
|
53 |
+
|
54 |
+
def truncate_long_sentences(sents):
|
55 |
+
|
56 |
+
MAX_SEQ_LEN = 200
|
57 |
+
new_sents = []
|
58 |
+
|
59 |
+
for sent in sents:
|
60 |
+
words = sent.split()
|
61 |
+
num_words = len(words)
|
62 |
+
if num_words > MAX_SEQ_LEN:
|
63 |
+
print_str = " ".join(words[:5]) + " .... " + " ".join(words[-5:])
|
64 |
+
sent = " ".join(words[:MAX_SEQ_LEN])
|
65 |
+
print(
|
66 |
+
f"WARNING: Sentence {print_str} truncated to 200 tokens as it exceeds maximum length limit"
|
67 |
+
)
|
68 |
+
|
69 |
+
new_sents.append(sent)
|
70 |
+
return new_sents
|
71 |
+
|
72 |
+
|
73 |
+
class Model:
|
74 |
+
def __init__(self, expdir):
|
75 |
+
self.expdir = expdir
|
76 |
+
self.en_tok = MosesTokenizer(lang="en")
|
77 |
+
self.en_normalizer = MosesPunctNormalizer()
|
78 |
+
self.en_detok = MosesDetokenizer(lang="en")
|
79 |
+
self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()
|
80 |
+
print("Initializing vocab and bpe")
|
81 |
+
self.vocabulary = read_vocabulary(
|
82 |
+
codecs.open(f"{expdir}/vocab/vocab.SRC", encoding="utf-8"), 5
|
83 |
+
)
|
84 |
+
self.bpe = BPE(
|
85 |
+
codecs.open(f"{expdir}/vocab/bpe_codes.32k.SRC", encoding="utf-8"),
|
86 |
+
-1,
|
87 |
+
"@@",
|
88 |
+
self.vocabulary,
|
89 |
+
None,
|
90 |
+
)
|
91 |
+
|
92 |
+
print("Initializing model for translation")
|
93 |
+
# initialize the model
|
94 |
+
self.translator = Translator(
|
95 |
+
f"{expdir}/final_bin", f"{expdir}/model/checkpoint_best.pt", batch_size=100
|
96 |
+
)
|
97 |
+
|
98 |
+
# translate a batch of sentences from src_lang to tgt_lang
|
99 |
+
def batch_translate(self, batch, src_lang, tgt_lang):
|
100 |
+
|
101 |
+
assert isinstance(batch, list)
|
102 |
+
preprocessed_sents = self.preprocess(batch, lang=src_lang)
|
103 |
+
bpe_sents = self.apply_bpe(preprocessed_sents)
|
104 |
+
tagged_sents = apply_lang_tags(bpe_sents, src_lang, tgt_lang)
|
105 |
+
tagged_sents = truncate_long_sentences(tagged_sents)
|
106 |
+
|
107 |
+
translations = self.translator.translate(tagged_sents)
|
108 |
+
postprocessed_sents = self.postprocess(translations, tgt_lang)
|
109 |
+
|
110 |
+
return postprocessed_sents
|
111 |
+
|
112 |
+
# translate a paragraph from src_lang to tgt_lang
|
113 |
+
def translate_paragraph(self, paragraph, src_lang, tgt_lang):
|
114 |
+
|
115 |
+
assert isinstance(paragraph, str)
|
116 |
+
sents = split_sentences(paragraph, src_lang)
|
117 |
+
|
118 |
+
postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
|
119 |
+
|
120 |
+
translated_paragraph = " ".join(postprocessed_sents)
|
121 |
+
|
122 |
+
return translated_paragraph
|
123 |
+
|
124 |
+
def preprocess_sent(self, sent, normalizer, lang):
|
125 |
+
if lang == "en":
|
126 |
+
return " ".join(
|
127 |
+
self.en_tok.tokenize(
|
128 |
+
self.en_normalizer.normalize(sent.strip()), escape=False
|
129 |
+
)
|
130 |
+
)
|
131 |
+
else:
|
132 |
+
# line = indic_detokenize.trivial_detokenize(line.strip(), lang)
|
133 |
+
return unicode_transliterate.UnicodeIndicTransliterator.transliterate(
|
134 |
+
" ".join(
|
135 |
+
indic_tokenize.trivial_tokenize(
|
136 |
+
normalizer.normalize(sent.strip()), lang
|
137 |
+
)
|
138 |
+
),
|
139 |
+
lang,
|
140 |
+
"hi",
|
141 |
+
).replace(" ् ", "्")
|
142 |
+
|
143 |
+
def preprocess(self, sents, lang):
|
144 |
+
"""
|
145 |
+
Normalize, tokenize and script convert(for Indic)
|
146 |
+
return number of sentences input file
|
147 |
+
|
148 |
+
"""
|
149 |
+
|
150 |
+
if lang == "en":
|
151 |
+
|
152 |
+
# processed_sents = Parallel(n_jobs=-1, backend="multiprocessing")(
|
153 |
+
# delayed(preprocess_line)(line, None, lang) for line in tqdm(sents, total=num_lines)
|
154 |
+
# )
|
155 |
+
processed_sents = [
|
156 |
+
self.preprocess_sent(line, None, lang) for line in tqdm(sents)
|
157 |
+
]
|
158 |
+
|
159 |
+
else:
|
160 |
+
normfactory = indic_normalize.IndicNormalizerFactory()
|
161 |
+
normalizer = normfactory.get_normalizer(lang)
|
162 |
+
|
163 |
+
# processed_sents = Parallel(n_jobs=-1, backend="multiprocessing")(
|
164 |
+
# delayed(preprocess_line)(line, normalizer, lang) for line in tqdm(infile, total=num_lines)
|
165 |
+
# )
|
166 |
+
processed_sents = [
|
167 |
+
self.preprocess_sent(line, normalizer, lang) for line in tqdm(sents)
|
168 |
+
]
|
169 |
+
|
170 |
+
return processed_sents
|
171 |
+
|
172 |
+
def postprocess(self, sents, lang, common_lang="hi"):
|
173 |
+
"""
|
174 |
+
parse fairseq interactive output, convert script back to native Indic script (in case of Indic languages) and detokenize.
|
175 |
+
|
176 |
+
infname: fairseq log file
|
177 |
+
outfname: output file of translation (sentences not translated contain the dummy string 'DUMMY_OUTPUT'
|
178 |
+
input_size: expected number of output sentences
|
179 |
+
lang: language
|
180 |
+
"""
|
181 |
+
postprocessed_sents = []
|
182 |
+
|
183 |
+
if lang == "en":
|
184 |
+
for sent in sents:
|
185 |
+
# outfile.write(en_detok.detokenize(sent.split(" ")) + "\n")
|
186 |
+
postprocessed_sents.append(self.en_detok.detokenize(sent.split(" ")))
|
187 |
+
else:
|
188 |
+
for sent in sents:
|
189 |
+
outstr = indic_detokenize.trivial_detokenize(
|
190 |
+
self.xliterator.transliterate(sent, common_lang, lang), lang
|
191 |
+
)
|
192 |
+
# outfile.write(outstr + "\n")
|
193 |
+
postprocessed_sents.append(outstr)
|
194 |
+
return postprocessed_sents
|
195 |
+
|
196 |
+
def apply_bpe(self, sents):
|
197 |
+
|
198 |
+
return [self.bpe.process_line(sent) for sent in sents]
|
legacy/apply_bpe_test_valid_notag.sh
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
expdir=$1 # EXPDIR
|
4 |
+
org_data_dir=$2
|
5 |
+
langs=$3
|
6 |
+
|
7 |
+
#`dirname $0`/env.sh
|
8 |
+
SUBWORD_NMT_DIR="subword-nmt"
|
9 |
+
echo "Apply to each language"
|
10 |
+
|
11 |
+
for dset in `echo test dev`
|
12 |
+
do
|
13 |
+
echo $dset
|
14 |
+
|
15 |
+
in_dset_dir="$org_data_dir/$dset"
|
16 |
+
out_dset_dir="$expdir/bpe/$dset"
|
17 |
+
|
18 |
+
for lang in $langs
|
19 |
+
do
|
20 |
+
|
21 |
+
echo Apply BPE for $dset "-" $lang
|
22 |
+
|
23 |
+
mkdir -p $out_dset_dir
|
24 |
+
|
25 |
+
python $SUBWORD_NMT_DIR/subword_nmt/apply_bpe.py \
|
26 |
+
-c $expdir/vocab/bpe_codes.32k.SRC_TGT \
|
27 |
+
--vocabulary $expdir/vocab/vocab.SRC \
|
28 |
+
--vocabulary-threshold 5 \
|
29 |
+
< $in_dset_dir/$dset.$lang \
|
30 |
+
> $out_dset_dir/$dset.$lang
|
31 |
+
|
32 |
+
done
|
33 |
+
done
|
legacy/apply_bpe_train_notag.sh
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
expdir=$1 # EXPDIR
|
4 |
+
|
5 |
+
#`dirname $0`/env.sh
|
6 |
+
SUBWORD_NMT_DIR="subword-nmt"
|
7 |
+
|
8 |
+
data_dir="$expdir/data"
|
9 |
+
train_file=$data_dir/train
|
10 |
+
bpe_file=$expdir/bpe/train/train
|
11 |
+
|
12 |
+
mkdir -p $expdir/bpe/train
|
13 |
+
|
14 |
+
echo "Apply to SRC corpus"
|
15 |
+
|
16 |
+
python $SUBWORD_NMT_DIR/subword_nmt/apply_bpe.py \
|
17 |
+
-c $expdir/vocab/bpe_codes.32k.SRC_TGT \
|
18 |
+
--vocabulary $expdir/vocab/vocab.SRC \
|
19 |
+
--vocabulary-threshold 5 \
|
20 |
+
--num-workers "-1" \
|
21 |
+
< $train_file.SRC \
|
22 |
+
> $bpe_file.SRC
|
23 |
+
|
24 |
+
echo "Apply to TGT corpus"
|
25 |
+
|
26 |
+
python $SUBWORD_NMT_DIR/subword_nmt/apply_bpe.py \
|
27 |
+
-c $expdir/vocab/bpe_codes.32k.SRC_TGT \
|
28 |
+
--vocabulary $expdir/vocab/vocab.TGT \
|
29 |
+
--vocabulary-threshold 5 \
|
30 |
+
--num-workers "-1" \
|
31 |
+
< $train_file.TGT \
|
32 |
+
> $bpe_file.TGT
|
33 |
+
|
legacy/env.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export SRC=''
|
3 |
+
|
4 |
+
## Python env directory where fairseq is installed
|
5 |
+
export PYTHON_ENV=''
|
6 |
+
|
7 |
+
export SUBWORD_NMT_DIR=''
|
8 |
+
export INDIC_RESOURCES_PATH=''
|
9 |
+
export INDIC_NLP_HOME=''
|
10 |
+
|
11 |
+
export CUDA_HOME=''
|
12 |
+
|
13 |
+
export PATH=$CUDA_HOME/bin:$INDIC_NLP_HOME:$PATH
|
14 |
+
export LD_LIBRARY_PATH=$CUDA_HOME/lib64
|
15 |
+
|
16 |
+
# set environment variable to control GPUS visible to the application
|
17 |
+
#export CUDA_VISIBLE_DEVICES="'
|
legacy/indictrans_workflow.ipynb
ADDED
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import os\n",
|
10 |
+
"import random\n",
|
11 |
+
"from tqdm.notebook import tqdm\n",
|
12 |
+
"from sacremoses import MosesPunctNormalizer\n",
|
13 |
+
"from sacremoses import MosesTokenizer\n",
|
14 |
+
"from sacremoses import MosesDetokenizer\n",
|
15 |
+
"from collections import defaultdict\n",
|
16 |
+
"import sacrebleu"
|
17 |
+
]
|
18 |
+
},
|
19 |
+
{
|
20 |
+
"cell_type": "code",
|
21 |
+
"execution_count": null,
|
22 |
+
"metadata": {},
|
23 |
+
"outputs": [],
|
24 |
+
"source": [
|
25 |
+
"# The path to the local git repo for Indic NLP library\n",
|
26 |
+
"INDIC_NLP_LIB_HOME=\"\"\n",
|
27 |
+
"\n",
|
28 |
+
"# The path to the local git repo for Indic NLP Resources\n",
|
29 |
+
"INDIC_NLP_RESOURCES=\"\"\n",
|
30 |
+
"\n",
|
31 |
+
"import sys\n",
|
32 |
+
"sys.path.append(r'{}'.format(INDIC_NLP_LIB_HOME))\n",
|
33 |
+
"\n",
|
34 |
+
"from indicnlp import common\n",
|
35 |
+
"common.set_resources_path(INDIC_NLP_RESOURCES)\n",
|
36 |
+
"\n",
|
37 |
+
"from indicnlp import loader\n",
|
38 |
+
"loader.load()"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": null,
|
44 |
+
"metadata": {},
|
45 |
+
"outputs": [],
|
46 |
+
"source": [
|
47 |
+
"import indicnlp\n",
|
48 |
+
"from indicnlp.tokenize import indic_tokenize\n",
|
49 |
+
"from indicnlp.tokenize import indic_detokenize\n",
|
50 |
+
"from indicnlp.normalize import indic_normalize\n",
|
51 |
+
"from indicnlp.transliterate import unicode_transliterate"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "code",
|
56 |
+
"execution_count": null,
|
57 |
+
"metadata": {},
|
58 |
+
"outputs": [],
|
59 |
+
"source": [
|
60 |
+
"LANGS=[\n",
|
61 |
+
" \"bn\",\n",
|
62 |
+
" \"gu\",\n",
|
63 |
+
" \"hi\",\n",
|
64 |
+
" \"kn\",\n",
|
65 |
+
" \"ml\",\n",
|
66 |
+
" \"mr\",\n",
|
67 |
+
" \"or\",\n",
|
68 |
+
" \"pa\",\n",
|
69 |
+
" \"ta\",\n",
|
70 |
+
" \"te\", \n",
|
71 |
+
"]"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": null,
|
77 |
+
"metadata": {},
|
78 |
+
"outputs": [],
|
79 |
+
"source": [
|
80 |
+
"def preprocess(infname,outfname,lang):\n",
|
81 |
+
" \"\"\"\n",
|
82 |
+
" Preparing each corpus file: \n",
|
83 |
+
" - Normalization\n",
|
84 |
+
" - Tokenization \n",
|
85 |
+
" - Script coversion to Devanagari for Indic scripts\n",
|
86 |
+
" \"\"\"\n",
|
87 |
+
" \n",
|
88 |
+
" ### reading \n",
|
89 |
+
" with open(infname,'r',encoding='utf-8') as infile, \\\n",
|
90 |
+
" open(outfname,'w',encoding='utf-8') as outfile:\n",
|
91 |
+
" \n",
|
92 |
+
" if lang=='en':\n",
|
93 |
+
" en_tok=MosesTokenizer(lang='en')\n",
|
94 |
+
" en_normalizer = MosesPunctNormalizer()\n",
|
95 |
+
" for line in tqdm(infile): \n",
|
96 |
+
" outline=' '.join(\n",
|
97 |
+
" en_tok.tokenize( \n",
|
98 |
+
" en_normalizer.normalize(line.strip()), \n",
|
99 |
+
" escape=False ) )\n",
|
100 |
+
" outfile.write(outline+'\\n')\n",
|
101 |
+
" \n",
|
102 |
+
" else:\n",
|
103 |
+
" normfactory=indic_normalize.IndicNormalizerFactory()\n",
|
104 |
+
" normalizer=normfactory.get_normalizer(lang)\n",
|
105 |
+
" for line in tqdm(infile): \n",
|
106 |
+
" outline=unicode_transliterate.UnicodeIndicTransliterator.transliterate(\n",
|
107 |
+
" ' '.join(\n",
|
108 |
+
" indic_tokenize.trivial_tokenize(\n",
|
109 |
+
" normalizer.normalize(line.strip()), lang) ), lang, 'hi').replace(' ् ','्')\n",
|
110 |
+
"\n",
|
111 |
+
"\n",
|
112 |
+
" outfile.write(outline+'\\n')"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": null,
|
118 |
+
"metadata": {},
|
119 |
+
"outputs": [],
|
120 |
+
"source": [
|
121 |
+
"def add_token(sent, tag_infos):\n",
|
122 |
+
" \"\"\" add special tokens specified by tag_infos to each element in list\n",
|
123 |
+
"\n",
|
124 |
+
" tag_infos: list of tuples (tag_type,tag)\n",
|
125 |
+
"\n",
|
126 |
+
" each tag_info results in a token of the form: __{tag_type}__{tag}__\n",
|
127 |
+
"\n",
|
128 |
+
" \"\"\"\n",
|
129 |
+
"\n",
|
130 |
+
" tokens=[]\n",
|
131 |
+
" for tag_type, tag in tag_infos:\n",
|
132 |
+
" token = '__' + tag_type + '__' + tag + '__'\n",
|
133 |
+
" tokens.append(token)\n",
|
134 |
+
"\n",
|
135 |
+
" return ' '.join(tokens) + ' ' + sent \n",
|
136 |
+
"\n",
|
137 |
+
"\n",
|
138 |
+
"def concat_data(data_dir, outdir, lang_pair_list, out_src_lang='SRC', out_trg_lang='TGT'):\n",
|
139 |
+
" \"\"\"\n",
|
140 |
+
" data_dir: input dir, contains directories for language pairs named l1-l2\n",
|
141 |
+
" \"\"\"\n",
|
142 |
+
" os.makedirs(outdir,exist_ok=True)\n",
|
143 |
+
"\n",
|
144 |
+
" out_src_fname='{}/train.{}'.format(outdir,out_src_lang)\n",
|
145 |
+
" out_trg_fname='{}/train.{}'.format(outdir,out_trg_lang)\n",
|
146 |
+
"# out_meta_fname='{}/metadata.txt'.format(outdir)\n",
|
147 |
+
"\n",
|
148 |
+
" print()\n",
|
149 |
+
" print(out_src_fname)\n",
|
150 |
+
" print(out_trg_fname)\n",
|
151 |
+
"# print(out_meta_fname)\n",
|
152 |
+
"\n",
|
153 |
+
" ### concatenate train data \n",
|
154 |
+
" if os.path.isfile(out_src_fname):\n",
|
155 |
+
" os.unlink(out_src_fname)\n",
|
156 |
+
" if os.path.isfile(out_trg_fname):\n",
|
157 |
+
" os.unlink(out_trg_fname)\n",
|
158 |
+
"# if os.path.isfile(out_meta_fname):\n",
|
159 |
+
"# os.unlink(out_meta_fname)\n",
|
160 |
+
"\n",
|
161 |
+
" for src_lang, trg_lang in tqdm(lang_pair_list):\n",
|
162 |
+
" print('src: {}, tgt:{}'.format(src_lang,trg_lang)) \n",
|
163 |
+
"\n",
|
164 |
+
" in_src_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,src_lang)\n",
|
165 |
+
" in_trg_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,trg_lang)\n",
|
166 |
+
"\n",
|
167 |
+
" print(in_src_fname)\n",
|
168 |
+
" os.system('cat {} >> {}'.format(in_src_fname,out_src_fname))\n",
|
169 |
+
"\n",
|
170 |
+
" print(in_trg_fname)\n",
|
171 |
+
" os.system('cat {} >> {}'.format(in_trg_fname,out_trg_fname)) \n",
|
172 |
+
" \n",
|
173 |
+
" \n",
|
174 |
+
"# with open('{}/lang_pairs.txt'.format(outdir),'w',encoding='utf-8') as lpfile: \n",
|
175 |
+
"# lpfile.write('\\n'.join( [ '-'.join(x) for x in lang_pair_list ] ))\n",
|
176 |
+
" \n",
|
177 |
+
" corpus_stats(data_dir, outdir, lang_pair_list)\n",
|
178 |
+
" \n",
|
179 |
+
"def corpus_stats(data_dir, outdir, lang_pair_list):\n",
|
180 |
+
" \"\"\"\n",
|
181 |
+
" data_dir: input dir, contains directories for language pairs named l1-l2\n",
|
182 |
+
" \"\"\"\n",
|
183 |
+
"\n",
|
184 |
+
" with open('{}/lang_pairs.txt'.format(outdir),'w',encoding='utf-8') as lpfile: \n",
|
185 |
+
"\n",
|
186 |
+
" for src_lang, trg_lang in tqdm(lang_pair_list):\n",
|
187 |
+
" print('src: {}, tgt:{}'.format(src_lang,trg_lang)) \n",
|
188 |
+
"\n",
|
189 |
+
" in_src_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,src_lang)\n",
|
190 |
+
" # in_trg_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,trg_lang)\n",
|
191 |
+
"\n",
|
192 |
+
" print(in_src_fname)\n",
|
193 |
+
" corpus_size=0\n",
|
194 |
+
" with open(in_src_fname,'r',encoding='utf-8') as infile:\n",
|
195 |
+
" corpus_size=sum(map(lambda x:1,infile))\n",
|
196 |
+
" \n",
|
197 |
+
" lpfile.write('{}\\t{}\\t{}\\n'.format(src_lang,trg_lang,corpus_size))\n",
|
198 |
+
" \n",
|
199 |
+
"def generate_lang_tag_iterator(infname):\n",
|
200 |
+
" with open(infname,'r',encoding='utf-8') as infile:\n",
|
201 |
+
" for line in infile:\n",
|
202 |
+
" src,tgt,count=line.strip().split('\\t')\n",
|
203 |
+
" count=int(count)\n",
|
204 |
+
" for _ in range(count):\n",
|
205 |
+
" yield (src,tgt) "
|
206 |
+
]
|
207 |
+
},
|
208 |
+
{
|
209 |
+
"cell_type": "code",
|
210 |
+
"execution_count": null,
|
211 |
+
"metadata": {},
|
212 |
+
"outputs": [],
|
213 |
+
"source": [
|
214 |
+
"#### directory containing all experiments \n",
|
215 |
+
"## one directory per experiment \n",
|
216 |
+
"EXPBASEDIR=''\n",
|
217 |
+
"\n",
|
218 |
+
"### directory containing data\n",
|
219 |
+
"## contains 3 directories: train test dev\n",
|
220 |
+
"## train directory structure: \n",
|
221 |
+
"## - There is one directory for each language pair\n",
|
222 |
+
"## - Directory naming convention lang1-lang2 (you need another directory/softlink for lang2-lang1)\n",
|
223 |
+
"## - Each directory contains 6 files: {train,test,dev}.{lang1,lang2}\n",
|
224 |
+
"## test & dev directory structure \n",
|
225 |
+
"## - test: contains files {test.l1,test.l2,test.l3} - assumes parallel test files like the wat2021 dataset\n",
|
226 |
+
"## - valid: contains files {dev.l1,dev.l2,dev.l3} - assumes parallel test files like the wat2021 dataset\n",
|
227 |
+
"## All files are tokenized\n",
|
228 |
+
"ORG_DATA_DIR='{d}/consolidated_unique_preprocessed'.format(d=BASEDIR)\n",
|
229 |
+
"\n"
|
230 |
+
]
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"cell_type": "markdown",
|
234 |
+
"metadata": {},
|
235 |
+
"source": [
|
236 |
+
"# Exp2 (M2O)\n",
|
237 |
+
"\n",
|
238 |
+
"- All *-en "
|
239 |
+
]
|
240 |
+
},
|
241 |
+
{
|
242 |
+
"cell_type": "markdown",
|
243 |
+
"metadata": {},
|
244 |
+
"source": [
|
245 |
+
"**Params**"
|
246 |
+
]
|
247 |
+
},
|
248 |
+
{
|
249 |
+
"cell_type": "code",
|
250 |
+
"execution_count": null,
|
251 |
+
"metadata": {},
|
252 |
+
"outputs": [],
|
253 |
+
"source": [
|
254 |
+
"expname='exp2_m2o_baseline'\n",
|
255 |
+
"expdir='{}/{}'.format(EXPBASEDIR,expname)\n",
|
256 |
+
"\n",
|
257 |
+
"lang_pair_list=[]\n",
|
258 |
+
"for lang in LANGS: \n",
|
259 |
+
" lang_pair_list.append([lang,'en'])"
|
260 |
+
]
|
261 |
+
},
|
262 |
+
{
|
263 |
+
"cell_type": "markdown",
|
264 |
+
"metadata": {},
|
265 |
+
"source": [
|
266 |
+
"**Create Train Corpus**"
|
267 |
+
]
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"cell_type": "code",
|
271 |
+
"execution_count": null,
|
272 |
+
"metadata": {},
|
273 |
+
"outputs": [],
|
274 |
+
"source": [
|
275 |
+
"indir='{}/train'.format(ORG_DATA_DIR)\n",
|
276 |
+
"outdir='{}/data'.format(expdir)\n",
|
277 |
+
"\n",
|
278 |
+
"# print(lang_pair_list)\n",
|
279 |
+
"concat_data(indir,outdir,lang_pair_list)"
|
280 |
+
]
|
281 |
+
},
|
282 |
+
{
|
283 |
+
"cell_type": "markdown",
|
284 |
+
"metadata": {},
|
285 |
+
"source": [
|
286 |
+
"**Learn BPE**"
|
287 |
+
]
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"cell_type": "code",
|
291 |
+
"execution_count": null,
|
292 |
+
"metadata": {},
|
293 |
+
"outputs": [],
|
294 |
+
"source": [
|
295 |
+
"!echo ./learn_bpe.sh {expdir}"
|
296 |
+
]
|
297 |
+
},
|
298 |
+
{
|
299 |
+
"cell_type": "code",
|
300 |
+
"execution_count": null,
|
301 |
+
"metadata": {},
|
302 |
+
"outputs": [],
|
303 |
+
"source": [
|
304 |
+
"!echo ./apply_bpe_train_notag.sh {expdir}"
|
305 |
+
]
|
306 |
+
},
|
307 |
+
{
|
308 |
+
"cell_type": "code",
|
309 |
+
"execution_count": null,
|
310 |
+
"metadata": {},
|
311 |
+
"outputs": [],
|
312 |
+
"source": [
|
313 |
+
"!echo ./apply_bpe_test_valid_notag.sh {expdir} {ORG_DATA_DIR} {'\"'+' '.join(LANGS+['en'])+'\"'}"
|
314 |
+
]
|
315 |
+
},
|
316 |
+
{
|
317 |
+
"cell_type": "markdown",
|
318 |
+
"metadata": {},
|
319 |
+
"source": [
|
320 |
+
"**Add language tags to train**"
|
321 |
+
]
|
322 |
+
},
|
323 |
+
{
|
324 |
+
"cell_type": "code",
|
325 |
+
"execution_count": null,
|
326 |
+
"metadata": {},
|
327 |
+
"outputs": [],
|
328 |
+
"source": [
|
329 |
+
"dset='train' \n",
|
330 |
+
"\n",
|
331 |
+
"src_fname='{expdir}/bpe/train/{dset}.SRC'.format(expdir=expdir,dset=dset)\n",
|
332 |
+
"tgt_fname='{expdir}/bpe/train/{dset}.TGT'.format(expdir=expdir,dset=dset)\n",
|
333 |
+
"meta_fname='{expdir}/data/lang_pairs.txt'.format(expdir=expdir,dset=dset)\n",
|
334 |
+
" \n",
|
335 |
+
"out_src_fname='{expdir}/final/{dset}.SRC'.format(expdir=expdir,dset=dset)\n",
|
336 |
+
"out_tgt_fname='{expdir}/final/{dset}.TGT'.format(expdir=expdir,dset=dset)\n",
|
337 |
+
"\n",
|
338 |
+
"lang_tag_iterator=generate_lang_tag_iterator(meta_fname)\n",
|
339 |
+
"\n",
|
340 |
+
"print(expdir)\n",
|
341 |
+
"os.makedirs('{expdir}/final'.format(expdir=expdir),exist_ok=True)\n",
|
342 |
+
"\n",
|
343 |
+
"with open(src_fname,'r',encoding='utf-8') as srcfile, \\\n",
|
344 |
+
" open(tgt_fname,'r',encoding='utf-8') as tgtfile, \\\n",
|
345 |
+
" open(out_src_fname,'w',encoding='utf-8') as outsrcfile, \\\n",
|
346 |
+
" open(out_tgt_fname,'w',encoding='utf-8') as outtgtfile: \n",
|
347 |
+
"\n",
|
348 |
+
" for (l1,l2), src_sent, tgt_sent in tqdm(zip(lang_tag_iterator, srcfile, tgtfile)):\n",
|
349 |
+
" outsrcfile.write(add_token(src_sent.strip(),[('src',l1),('tgt',l2)]) + '\\n' )\n",
|
350 |
+
" outtgtfile.write(tgt_sent.strip()+'\\n')"
|
351 |
+
]
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"cell_type": "markdown",
|
355 |
+
"metadata": {},
|
356 |
+
"source": [
|
357 |
+
"**Add language tags to valid**\n",
|
358 |
+
"\n",
|
359 |
+
"- add language tags, create parallel corpus\n",
|
360 |
+
"- sample 20\\% for validation set \n",
|
361 |
+
"- Create final validation set"
|
362 |
+
]
|
363 |
+
},
|
364 |
+
{
|
365 |
+
"cell_type": "code",
|
366 |
+
"execution_count": null,
|
367 |
+
"metadata": {},
|
368 |
+
"outputs": [],
|
369 |
+
"source": [
|
370 |
+
"dset='dev' \n",
|
371 |
+
"out_src_fname='{expdir}/final/{dset}.SRC'.format(\n",
|
372 |
+
" expdir=expdir,dset=dset)\n",
|
373 |
+
"out_tgt_fname='{expdir}/final/{dset}.TGT'.format(\n",
|
374 |
+
" expdir=expdir,dset=dset)\n",
|
375 |
+
"\n",
|
376 |
+
"os.makedirs('{expdir}/final'.format(expdir=expdir),exist_ok=True)\n",
|
377 |
+
"\n",
|
378 |
+
"print('Processing validation files') \n",
|
379 |
+
"consolidated_dset=[]\n",
|
380 |
+
"for l1, l2 in tqdm(lang_pair_list):\n",
|
381 |
+
" src_fname='{expdir}/bpe/{dset}/{dset}.{lang}'.format(\n",
|
382 |
+
" expdir=expdir,dset=dset,lang=l1)\n",
|
383 |
+
" tgt_fname='{expdir}/bpe/{dset}/{dset}.{lang}'.format(\n",
|
384 |
+
" expdir=expdir,dset=dset,lang=l2)\n",
|
385 |
+
"# print(src_fname)\n",
|
386 |
+
"# print(os.path.exists(src_fname))\n",
|
387 |
+
" with open(src_fname,'r',encoding='utf-8') as srcfile, \\\n",
|
388 |
+
" open(tgt_fname,'r',encoding='utf-8') as tgtfile:\n",
|
389 |
+
" for src_sent, tgt_sent in zip(srcfile,tgtfile):\n",
|
390 |
+
" consolidated_dset.append(\n",
|
391 |
+
" ( add_token(src_sent.strip(),[('src',l1),('tgt',l2)]),\n",
|
392 |
+
" tgt_sent.strip() )\n",
|
393 |
+
" )\n",
|
394 |
+
"\n",
|
395 |
+
"print('Create validation set') \n",
|
396 |
+
"random.shuffle(consolidated_dset)\n",
|
397 |
+
"final_set=consolidated_dset[:len(consolidated_dset)//5] \n",
|
398 |
+
"\n",
|
399 |
+
"print('Original set size: {}'.format(len(consolidated_dset))) \n",
|
400 |
+
"print('Sampled set size: {}'.format(len(final_set))) \n",
|
401 |
+
"\n",
|
402 |
+
"print('Write validation set')\n",
|
403 |
+
"\n",
|
404 |
+
"with open(out_src_fname,'w',encoding='utf-8') as srcfile, \\\n",
|
405 |
+
" open(out_tgt_fname,'w',encoding='utf-8') as tgtfile:\n",
|
406 |
+
" for src_sent, tgt_sent in final_set: \n",
|
407 |
+
" srcfile.write(src_sent+'\\n')\n",
|
408 |
+
" tgtfile.write(tgt_sent+'\\n')\n"
|
409 |
+
]
|
410 |
+
},
|
411 |
+
{
|
412 |
+
"cell_type": "markdown",
|
413 |
+
"metadata": {},
|
414 |
+
"source": [
|
415 |
+
"**Add language tags to test**\n",
|
416 |
+
"\n",
|
417 |
+
"- add language tags, create parallel corpus all M2O language pairs \n",
|
418 |
+
"- Create final test set"
|
419 |
+
]
|
420 |
+
},
|
421 |
+
{
|
422 |
+
"cell_type": "code",
|
423 |
+
"execution_count": null,
|
424 |
+
"metadata": {},
|
425 |
+
"outputs": [],
|
426 |
+
"source": [
|
427 |
+
"dset='test' \n",
|
428 |
+
"out_src_fname='{expdir}/final/{dset}.SRC'.format(\n",
|
429 |
+
" expdir=expdir,dset=dset)\n",
|
430 |
+
"out_tgt_fname='{expdir}/final/{dset}.TGT'.format(\n",
|
431 |
+
" expdir=expdir,dset=dset)\n",
|
432 |
+
"\n",
|
433 |
+
"os.makedirs('{expdir}/final'.format(expdir=expdir),exist_ok=True)\n",
|
434 |
+
"\n",
|
435 |
+
"print('Processing test files') \n",
|
436 |
+
"consolidated_dset=[]\n",
|
437 |
+
"for l1, l2 in tqdm(lang_pair_list):\n",
|
438 |
+
" src_fname='{expdir}/bpe/{dset}/{dset}.{lang}'.format(\n",
|
439 |
+
" expdir=expdir,dset=dset,lang=l1)\n",
|
440 |
+
" tgt_fname='{expdir}/bpe/{dset}/{dset}.{lang}'.format(\n",
|
441 |
+
" expdir=expdir,dset=dset,lang=l2)\n",
|
442 |
+
"# print(src_fname)\n",
|
443 |
+
"# print(os.path.exists(src_fname))\n",
|
444 |
+
" with open(src_fname,'r',encoding='utf-8') as srcfile, \\\n",
|
445 |
+
" open(tgt_fname,'r',encoding='utf-8') as tgtfile:\n",
|
446 |
+
" for src_sent, tgt_sent in zip(srcfile,tgtfile):\n",
|
447 |
+
" consolidated_dset.append(\n",
|
448 |
+
" ( add_token(src_sent.strip(),[('src',l1),('tgt',l2)]),\n",
|
449 |
+
" tgt_sent.strip() )\n",
|
450 |
+
" )\n",
|
451 |
+
"\n",
|
452 |
+
"print('Final set size: {}'.format(len(consolidated_dset))) \n",
|
453 |
+
" \n",
|
454 |
+
"print('Write test set')\n",
|
455 |
+
"print('testset truncated')\n",
|
456 |
+
"\n",
|
457 |
+
"with open(out_src_fname,'w',encoding='utf-8') as srcfile, \\\n",
|
458 |
+
" open(out_tgt_fname,'w',encoding='utf-8') as tgtfile:\n",
|
459 |
+
" for lno, (src_sent, tgt_sent) in enumerate(consolidated_dset,1):\n",
|
460 |
+
" \n",
|
461 |
+
" s=src_sent.strip().split(' ')\n",
|
462 |
+
" t=tgt_sent.strip().split(' ')\n",
|
463 |
+
" \n",
|
464 |
+
" if len(s) > 200 or len(t) > 200:\n",
|
465 |
+
" print('exp: {}, pair: ({},{}), lno: {}: lens: ({},{})'.format(expname,l1,l2,lno,len(s),len(t))) \n",
|
466 |
+
" \n",
|
467 |
+
" src_sent=' '.join( s[:min(len(s),200)] )\n",
|
468 |
+
" tgt_sent=' '.join( t[:min(len(t),200)] )\n",
|
469 |
+
" \n",
|
470 |
+
" srcfile.write(src_sent+'\\n')\n",
|
471 |
+
" tgtfile.write(tgt_sent+'\\n')"
|
472 |
+
]
|
473 |
+
},
|
474 |
+
{
|
475 |
+
"cell_type": "markdown",
|
476 |
+
"metadata": {},
|
477 |
+
"source": [
|
478 |
+
"**Binarize data**"
|
479 |
+
]
|
480 |
+
},
|
481 |
+
{
|
482 |
+
"cell_type": "code",
|
483 |
+
"execution_count": null,
|
484 |
+
"metadata": {},
|
485 |
+
"outputs": [],
|
486 |
+
"source": [
|
487 |
+
"!echo ./binarize_training_exp.sh {expdir} SRC TGT"
|
488 |
+
]
|
489 |
+
},
|
490 |
+
{
|
491 |
+
"cell_type": "markdown",
|
492 |
+
"metadata": {},
|
493 |
+
"source": [
|
494 |
+
"**Training Command**"
|
495 |
+
]
|
496 |
+
},
|
497 |
+
{
|
498 |
+
"cell_type": "code",
|
499 |
+
"execution_count": null,
|
500 |
+
"metadata": {},
|
501 |
+
"outputs": [],
|
502 |
+
"source": [
|
503 |
+
"%%bash \n",
|
504 |
+
"\n",
|
505 |
+
"python train.py {expdir}/final_bin \\\n",
|
506 |
+
" --arch transformer \\\n",
|
507 |
+
" --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 1.0 \\\n",
|
508 |
+
" --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \\\n",
|
509 |
+
" --dropout 0.2 \\\n",
|
510 |
+
" --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \\\n",
|
511 |
+
" --max-tokens 8192 \\\n",
|
512 |
+
" --max-update 1000000 \\\n",
|
513 |
+
" --max-source-positions 200 \\\n",
|
514 |
+
" --max-target-positions 200 \\\n",
|
515 |
+
" --tensorboard-logdir {expdir}/tensorboard \\\n",
|
516 |
+
" --save-dir {expdir}/model \\\n",
|
517 |
+
" --required-batch-size-multiple 8 \\\n",
|
518 |
+
" --save-interval 1 \\\n",
|
519 |
+
" --keep-last-epochs 5 \\\n",
|
520 |
+
" --patience 5 \\\n",
|
521 |
+
" --fp16"
|
522 |
+
]
|
523 |
+
},
|
524 |
+
{
|
525 |
+
"cell_type": "markdown",
|
526 |
+
"metadata": {},
|
527 |
+
"source": [
|
528 |
+
"**Cleanup**"
|
529 |
+
]
|
530 |
+
},
|
531 |
+
{
|
532 |
+
"cell_type": "code",
|
533 |
+
"execution_count": null,
|
534 |
+
"metadata": {},
|
535 |
+
"outputs": [],
|
536 |
+
"source": [
|
537 |
+
"# os.unlink('{}')\n",
|
538 |
+
"\n",
|
539 |
+
"to_delete=[\n",
|
540 |
+
" '{expdir}/data/train.SRC'.format(expdir=expdir,dset=dset),\n",
|
541 |
+
" '{expdir}/data/train.TGT'.format(expdir=expdir,dset=dset),\n",
|
542 |
+
" '{expdir}/bpe/train/train.SRC'.format(expdir=expdir,dset=dset),\n",
|
543 |
+
" '{expdir}/bpe/train/train.TGT'.format(expdir=expdir,dset=dset),\n",
|
544 |
+
"]`\n",
|
545 |
+
"\n",
|
546 |
+
"for fname in to_delete:\n",
|
547 |
+
" os.unlink(fname)"
|
548 |
+
]
|
549 |
+
},
|
550 |
+
{
|
551 |
+
"cell_type": "markdown",
|
552 |
+
"metadata": {},
|
553 |
+
"source": [
|
554 |
+
"**Evaluation**"
|
555 |
+
]
|
556 |
+
},
|
557 |
+
{
|
558 |
+
"cell_type": "code",
|
559 |
+
"execution_count": null,
|
560 |
+
"metadata": {},
|
561 |
+
"outputs": [],
|
562 |
+
"source": [
|
563 |
+
"dset='test' \n",
|
564 |
+
"consolidated_testoutput_fname='{expdir}/evaluations/test/default/test.SRC_TGT.TGT'.format(expdir=expdir)\n",
|
565 |
+
"consolidated_testoutput_log_fname='{}.log'.format(consolidated_testoutput_fname)\n",
|
566 |
+
"metrics_fname='{expdir}/evaluations/test/default/test.metrics.tsv'.format(expdir=expdir)\n",
|
567 |
+
" \n",
|
568 |
+
"test_set_size=2390\n",
|
569 |
+
"\n",
|
570 |
+
"consolidated_testoutput=[]\n",
|
571 |
+
"with open(consolidated_testoutput_log_fname,'r',encoding='utf-8') as hypfile:\n",
|
572 |
+
" consolidated_testoutput= list(map(lambda x: x.strip(), filter(lambda x: x.startswith('H-'),hypfile) ))\n",
|
573 |
+
" consolidated_testoutput.sort(key=lambda x: int(x.split('\\t')[0].split('-')[1]))\n",
|
574 |
+
" consolidated_testoutput=[ x.split('\\t')[2] for x in consolidated_testoutput ]\n",
|
575 |
+
"\n",
|
576 |
+
"os.makedirs('{expdir}/evaluations/test/default'.format(expdir=expdir),exist_ok=True)\n",
|
577 |
+
"\n",
|
578 |
+
"with open(consolidated_testoutput_fname,'w',encoding='utf-8') as finalhypfile:\n",
|
579 |
+
" for sent in consolidated_testoutput:\n",
|
580 |
+
" finalhypfile.write(sent+'\\n')\n",
|
581 |
+
"\n",
|
582 |
+
"print('Processing test files') \n",
|
583 |
+
"with open(metrics_fname,'w',encoding='utf-8') as metrics_file: \n",
|
584 |
+
" for i, (l1, l2) in enumerate(tqdm(lang_pair_list)):\n",
|
585 |
+
"\n",
|
586 |
+
" start=i*test_set_size\n",
|
587 |
+
" end=(i+1)*test_set_size\n",
|
588 |
+
" hyps=consolidated_testoutput[start:end]\n",
|
589 |
+
" ref_fname='{expdir}/{dset}/{dset}.{lang}'.format(\n",
|
590 |
+
" expdir=ORG_DATA_DIR,dset=dset,lang=l2)\n",
|
591 |
+
"\n",
|
592 |
+
" refs=[]\n",
|
593 |
+
" with open(ref_fname,'r',encoding='utf-8') as reffile:\n",
|
594 |
+
" refs.extend(map(lambda x:x.strip(),reffile))\n",
|
595 |
+
"\n",
|
596 |
+
" assert(len(hyps)==len(refs))\n",
|
597 |
+
"\n",
|
598 |
+
" bleu=sacrebleu.corpus_bleu(hyps,[refs],tokenize='none')\n",
|
599 |
+
"\n",
|
600 |
+
" print('{} {} {} {}'.format(l1,l2,bleu.score,bleu.prec_str))\n",
|
601 |
+
" metrics_file.write('{}\\t{}\\t{}\\t{}\\t{}\\n'.format(expname,l1,l2,bleu.score,bleu.prec_str))\n",
|
602 |
+
" "
|
603 |
+
]
|
604 |
+
}
|
605 |
+
],
|
606 |
+
"metadata": {
|
607 |
+
"kernelspec": {
|
608 |
+
"display_name": "Python 3",
|
609 |
+
"language": "python",
|
610 |
+
"name": "python3"
|
611 |
+
},
|
612 |
+
"language_info": {
|
613 |
+
"codemirror_mode": {
|
614 |
+
"name": "ipython",
|
615 |
+
"version": 3
|
616 |
+
},
|
617 |
+
"file_extension": ".py",
|
618 |
+
"mimetype": "text/x-python",
|
619 |
+
"name": "python",
|
620 |
+
"nbconvert_exporter": "python",
|
621 |
+
"pygments_lexer": "ipython3",
|
622 |
+
"version": "3.7.0"
|
623 |
+
},
|
624 |
+
"toc": {
|
625 |
+
"base_numbering": 1,
|
626 |
+
"nav_menu": {
|
627 |
+
"height": "243.993px",
|
628 |
+
"width": "160px"
|
629 |
+
},
|
630 |
+
"number_sections": true,
|
631 |
+
"sideBar": true,
|
632 |
+
"skip_h1_title": false,
|
633 |
+
"title_cell": "Table of Contents",
|
634 |
+
"title_sidebar": "Contents",
|
635 |
+
"toc_cell": false,
|
636 |
+
"toc_position": {},
|
637 |
+
"toc_section_display": true,
|
638 |
+
"toc_window_display": false
|
639 |
+
}
|
640 |
+
},
|
641 |
+
"nbformat": 4,
|
642 |
+
"nbformat_minor": 4
|
643 |
+
}
|
legacy/install_fairseq.sh
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#NVIDIA CUDA download
|
2 |
+
wget "https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux"
|
3 |
+
wget "http://developer.download.nvidia.com/compute/cuda/10.0/Prod/patches/1/cuda_10.0.130.1_linux.run"
|
4 |
+
|
5 |
+
## do not install drivers (See this: https://docs.nvidia.com/deploy/cuda-compatibility/index.html)
|
6 |
+
sudo sh "cuda_10.0.130_410.48_linux"
|
7 |
+
sudo sh "cuda_10.0.130.1_linux.run"
|
8 |
+
|
9 |
+
#Set environment variables
|
10 |
+
export CUDA_HOME=/usr/local/cuda-10.0
|
11 |
+
export PATH=$CUDA_HOME/bin:$PATH
|
12 |
+
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
|
13 |
+
|
14 |
+
# Install pytorch 1.2
|
15 |
+
python3 -m venv pytorch1.2
|
16 |
+
source pytorch1.2/bin/activate
|
17 |
+
which pip3
|
18 |
+
pip3 install torch==1.2.0 torchvision==0.4.0
|
19 |
+
|
20 |
+
# Install nccl
|
21 |
+
git clone https://github.com/NVIDIA/nccl.git
|
22 |
+
cd nccl
|
23 |
+
make src.build CUDA_HOME=$CUDA_HOME
|
24 |
+
sudo apt install build-essential devscripts debhelper fakeroot
|
25 |
+
make pkg.debian.build CUDA_HOME=$CUDA_HOME
|
26 |
+
sudo dpkg -i build/pkg/deb/libnccl2_2.7.8-1+cuda10.0_amd64.deb
|
27 |
+
sudo dpkg -i build/pkg/deb/libnccl-dev_2.7.8-1+cuda10.0_amd64.deb
|
28 |
+
sudo apt-get install -f
|
29 |
+
|
30 |
+
# Install Apex
|
31 |
+
git clone https://github.com/NVIDIA/apex
|
32 |
+
cd apex
|
33 |
+
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
|
34 |
+
--global-option="--deprecated_fused_adam" --global-option="--xentropy" \
|
35 |
+
--global-option="--fast_multihead_attn" ./
|
36 |
+
|
37 |
+
# Install PyArrow
|
38 |
+
pip install pyarrow
|
39 |
+
|
40 |
+
# Install fairseq
|
41 |
+
pip install --editable ./
|
42 |
+
|
43 |
+
# Install other dependencies
|
44 |
+
pip install sacrebleu
|
45 |
+
pip install tensorboardX --user
|
legacy/run_inference.sh
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
src_lang=${1:-hi}
|
2 |
+
tgt_lang=${2:-en}
|
3 |
+
bucket_path=${3:-gs://ai4b-anuvaad-nmt/baselines/transformer-base/baselines-${src_lang}-${tgt_lang}}
|
4 |
+
|
5 |
+
expdir=../baselines/baselines-${src_lang}-${tgt_lang}
|
6 |
+
|
7 |
+
if [[ -d $expdir ]]
|
8 |
+
then
|
9 |
+
echo "$expdir exists on your filesystem. Please delete this if you have made some changes to the bucket files and trying to redownload"
|
10 |
+
else
|
11 |
+
mkdir -p $expdir
|
12 |
+
mkdir -p $expdir/model
|
13 |
+
cd ../baselines
|
14 |
+
gsutil -m cp -r $bucket_path/vocab $expdir
|
15 |
+
gsutil -m cp -r $bucket_path/final_bin $expdir
|
16 |
+
gsutil -m cp $bucket_path/model/checkpoint_best.pt $expdir/model
|
17 |
+
cd ../indicTrans
|
18 |
+
fi
|
19 |
+
|
20 |
+
|
21 |
+
if [ $src_lang == 'hi' ] || [ $tgt_lang == 'hi' ]; then
|
22 |
+
#TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest anuvaad-legal tico19 sap-documentation-benchmark all)
|
23 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest wat-2018 wmt-news )
|
24 |
+
elif [ $src_lang == 'ta' ] || [ $tgt_lang == 'ta' ]; then
|
25 |
+
# TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest anuvaad-legal tico19 all)
|
26 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest wat-2018 wmt-news ufal-ta)
|
27 |
+
elif [ $src_lang == 'bn' ] || [ $tgt_lang == 'bn' ]; then
|
28 |
+
# TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal tico19 all)
|
29 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest wat-2018)
|
30 |
+
elif [ $src_lang == 'gu' ] || [ $tgt_lang == 'gu' ]; then
|
31 |
+
# TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest all)
|
32 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest wmt-news )
|
33 |
+
elif [ $src_lang == 'as' ] || [ $tgt_lang == 'as' ]; then
|
34 |
+
TEST_SETS=( pmi )
|
35 |
+
elif [ $src_lang == 'kn' ] || [ $tgt_lang == 'kn' ]; then
|
36 |
+
# TEST_SETS=( wat2021-devtest anuvaad-legal all)
|
37 |
+
TEST_SETS=( wat2021-devtest )
|
38 |
+
elif [ $src_lang == 'ml' ] || [ $tgt_lang == 'ml' ]; then
|
39 |
+
# TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal all)
|
40 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest wat-2018)
|
41 |
+
elif [ $src_lang == 'mr' ] || [ $tgt_lang == 'mr' ]; then
|
42 |
+
# TEST_SETS=( wat2021-devtest wat2020-devtest all)
|
43 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest )
|
44 |
+
elif [ $src_lang == 'or' ] || [ $tgt_lang == 'or' ]; then
|
45 |
+
TEST_SETS=( wat2021-devtest )
|
46 |
+
elif [ $src_lang == 'pa' ] || [ $tgt_lang == 'pa' ]; then
|
47 |
+
TEST_SETS=( wat2021-devtest )
|
48 |
+
elif [ $src_lang == 'te' ] || [ $tgt_lang == 'te' ]; then
|
49 |
+
# TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal all )
|
50 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest wat-2018)
|
51 |
+
fi
|
52 |
+
|
53 |
+
if [ $src_lang == 'en' ]; then
|
54 |
+
indic_lang=$tgt_lang
|
55 |
+
else
|
56 |
+
indic_lang=$src_lang
|
57 |
+
fi
|
58 |
+
|
59 |
+
|
60 |
+
for tset in ${TEST_SETS[@]};do
|
61 |
+
echo $tset $src_lang $tgt_lang
|
62 |
+
if [ $tset == 'wat2021-devtest' ]; then
|
63 |
+
SRC_FILE=${expdir}/benchmarks/$tset/test.$src_lang
|
64 |
+
REF_FILE=${expdir}/benchmarks/$tset/test.$tgt_lang
|
65 |
+
else
|
66 |
+
SRC_FILE=${expdir}/benchmarks/$tset/en-${indic_lang}/test.$src_lang
|
67 |
+
REF_FILE=${expdir}/benchmarks/$tset/en-${indic_lang}/test.$tgt_lang
|
68 |
+
fi
|
69 |
+
RESULTS_DIR=${expdir}/results/$tset
|
70 |
+
|
71 |
+
mkdir -p $RESULTS_DIR
|
72 |
+
|
73 |
+
bash translate.sh $SRC_FILE $RESULTS_DIR/${src_lang}-${tgt_lang} $src_lang $tgt_lang $expdir $REF_FILE
|
74 |
+
# for newline between different outputs
|
75 |
+
echo
|
76 |
+
done
|
77 |
+
# send the results to the bucket
|
78 |
+
gsutil -m cp -r $expdir/results $bucket_path
|
79 |
+
# clear up the space in the instance
|
80 |
+
# rm -r $expdir
|
legacy/run_joint_inference.sh
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
src_lang=${1:-en}
|
2 |
+
tgt_lang=${2:-indic}
|
3 |
+
bucket_path=${3:-gs://ai4b-anuvaad-nmt/models/transformer-4x/indictrans-${src_lang}-${tgt_lang}}
|
4 |
+
|
5 |
+
mkdir -p ../baselines
|
6 |
+
expdir=../baselines/baselines-${src_lang}-${tgt_lang}
|
7 |
+
|
8 |
+
if [[ -d $expdir ]]
|
9 |
+
then
|
10 |
+
echo "$expdir exists on your filesystem."
|
11 |
+
else
|
12 |
+
cd ../baselines
|
13 |
+
mkdir -p baselines-${src_lang}-${tgt_lang}/model
|
14 |
+
mkdir -p baselines-${src_lang}-${tgt_lang}/final_bin
|
15 |
+
cd baselines-${src_lang}-${tgt_lang}/model
|
16 |
+
gsutil -m cp $bucket_path/model/checkpoint_best.pt .
|
17 |
+
cd ..
|
18 |
+
gsutil -m cp $bucket_path/vocab .
|
19 |
+
gsutil -m cp $bucket_path/final_bin/dict.* final_bin
|
20 |
+
cd ../indicTrans
|
21 |
+
fi
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
if [ $src_lang == 'hi' ] || [ $tgt_lang == 'hi' ]; then
|
28 |
+
TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest anuvaad-legal tico19 sap-documentation-benchmark all)
|
29 |
+
elif [ $src_lang == 'ta' ] || [ $tgt_lang == 'ta' ]; then
|
30 |
+
TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest anuvaad-legal tico19 all)
|
31 |
+
elif [ $src_lang == 'bn' ] || [ $tgt_lang == 'bn' ]; then
|
32 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal tico19 all)
|
33 |
+
elif [ $src_lang == 'gu' ] || [ $tgt_lang == 'gu' ]; then
|
34 |
+
TEST_SETS=( wmt-news wat2021-devtest wat2020-devtest all)
|
35 |
+
elif [ $src_lang == 'as' ] || [ $tgt_lang == 'as' ]; then
|
36 |
+
TEST_SETS=( all )
|
37 |
+
elif [ $src_lang == 'kn' ] || [ $tgt_lang == 'kn' ]; then
|
38 |
+
TEST_SETS=( wat2021-devtest anuvaad-legal all)
|
39 |
+
elif [ $src_lang == 'ml' ] || [ $tgt_lang == 'ml' ]; then
|
40 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal all)
|
41 |
+
elif [ $src_lang == 'mr' ] || [ $tgt_lang == 'mr' ]; then
|
42 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest all)
|
43 |
+
elif [ $src_lang == 'or' ] || [ $tgt_lang == 'or' ]; then
|
44 |
+
TEST_SETS=( all )
|
45 |
+
elif [ $src_lang == 'pa' ] || [ $tgt_lang == 'pa' ]; then
|
46 |
+
TEST_SETS=( all )
|
47 |
+
elif [ $src_lang == 'te' ] || [ $tgt_lang == 'te' ]; then
|
48 |
+
TEST_SETS=( wat2021-devtest wat2020-devtest anuvaad-legal all )
|
49 |
+
fi
|
50 |
+
|
51 |
+
if [ $src_lang == 'en' ]; then
|
52 |
+
indic_lang=$tgt_lang
|
53 |
+
else
|
54 |
+
indic_lang=$src_lang
|
55 |
+
fi
|
56 |
+
|
57 |
+
|
58 |
+
for tset in ${TEST_SETS[@]};do
|
59 |
+
echo $tset $src_lang $tgt_lang
|
60 |
+
if [ $tset == 'wat2021-devtest' ]; then
|
61 |
+
SRC_FILE=${expdir}/devtest/$tset/test.$src_lang
|
62 |
+
REF_FILE=${expdir}/devtest/$tset/test.$tgt_lang
|
63 |
+
else
|
64 |
+
SRC_FILE=${expdir}/devtest/$tset/en-${indic_lang}/test.$src_lang
|
65 |
+
REF_FILE=${expdir}/devtest/$tset/en-${indic_lang}/test.$tgt_lang
|
66 |
+
fi
|
67 |
+
RESULTS_DIR=${expdir}/results/$tset
|
68 |
+
|
69 |
+
mkdir -p $RESULTS_DIR
|
70 |
+
|
71 |
+
bash joint_translate.sh $SRC_FILE $RESULTS_DIR/${src_lang}-${tgt_lang} $src_lang $tgt_lang $expdir $REF_FILE
|
72 |
+
# for newline between different outputs
|
73 |
+
echo
|
74 |
+
done
|
legacy/tpu_training_instructions.md
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Instructions to run on Google cloud TPUs
|
2 |
+
Before starting these steps, make sure to prepare the dataset (normalization -> bpe -> .. -> binarization) following the steps in indicTrans workflow or do these steps on a cpu instance before launching the tpu instance (to save time and costs)
|
3 |
+
|
4 |
+
### Creating TPU instance
|
5 |
+
|
6 |
+
- Create a cpu instance on gcp with `torch-xla` image like:
|
7 |
+
```bash
|
8 |
+
gcloud compute --project=${PROJECT_ID} instances create <name for your instance> \
|
9 |
+
--zone=<zone> \
|
10 |
+
--machine-type=n1-standard-16 \
|
11 |
+
--image-family=torch-xla \
|
12 |
+
--image-project=ml-images \
|
13 |
+
--boot-disk-size=200GB \
|
14 |
+
--scopes=https://www.googleapis.com/auth/cloud-platform
|
15 |
+
```
|
16 |
+
- Once the instance is created, Launch a Cloud TPU (from your cpu vm instance) using the following command (you can change the `accelerator_type` according to your needs):
|
17 |
+
```bash
|
18 |
+
gcloud compute tpus create <name for your TPU> \
|
19 |
+
--zone=<zone> \
|
20 |
+
--network=default \
|
21 |
+
--version=pytorch-1.7 \
|
22 |
+
--accelerator-type=v3-8
|
23 |
+
```
|
24 |
+
(or)
|
25 |
+
Create a new tpu using the GUI in https://console.cloud.google.com/compute/tpus and make sure to select `version` as `pytorch 1.7`.
|
26 |
+
|
27 |
+
- Once the tpu is launched, identify its ip address:
|
28 |
+
```bash
|
29 |
+
# you can run this inside cpu instance and note down the IP address which is located under the NETWORK_ENDPOINTS column
|
30 |
+
gcloud compute tpus list --zone=us-central1-a
|
31 |
+
```
|
32 |
+
(or)
|
33 |
+
Go to https://console.cloud.google.com/compute/tpus and note down ip address for the created TPU from the `interal ip` column
|
34 |
+
|
35 |
+
### Installing Fairseq, getting data on the cpu instance
|
36 |
+
|
37 |
+
- Activate the `torch xla 1.7` conda environment and install necessary libs for IndicTrans (**Excluding FairSeq**):
|
38 |
+
```bash
|
39 |
+
conda activate torch-xla-1.7
|
40 |
+
pip install sacremoses pandas mock sacrebleu tensorboardX pyarrow
|
41 |
+
```
|
42 |
+
- Configure environment variables for TPU:
|
43 |
+
```bash
|
44 |
+
export TPU_IP_ADDRESS=ip-address; \
|
45 |
+
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
|
46 |
+
```
|
47 |
+
- Download the prepared binarized data for FairSeq
|
48 |
+
|
49 |
+
- Clone the latest version of Fairseq (this supports tpu) and install from source. There is an [issue](https://github.com/pytorch/fairseq/issues/3259) with the latest commit and hence we use a different commit to install from source (This may have been fixed in the latest master but we have not tested it.)
|
50 |
+
```bash
|
51 |
+
git clone https://github.com/pytorch/fairseq.git
|
52 |
+
git checkout da9eaba12d82b9bfc1442f0e2c6fc1b895f4d35d
|
53 |
+
pip install --editable ./
|
54 |
+
```
|
55 |
+
|
56 |
+
- Start TPU training
|
57 |
+
```bash
|
58 |
+
# this is for using all tpu cores
|
59 |
+
export MKL_SERVICE_FORCE_INTEL=1
|
60 |
+
|
61 |
+
fairseq-train {expdir}/exp2_m2o_baseline/final_bin \
|
62 |
+
--max-source-positions=200 \
|
63 |
+
--max-target-positions=200 \
|
64 |
+
--max-update=1000000 \
|
65 |
+
--save-interval=5 \
|
66 |
+
--arch=transformer \
|
67 |
+
--attention-dropout=0.1 \
|
68 |
+
--criterion=label_smoothed_cross_entropy \
|
69 |
+
--source-lang=SRC \
|
70 |
+
--lr-scheduler=inverse_sqrt \
|
71 |
+
--skip-invalid-size-inputs-valid-test \
|
72 |
+
--target-lang=TGT \
|
73 |
+
--label-smoothing=0.1 \
|
74 |
+
--update-freq=1 \
|
75 |
+
--optimizer adam \
|
76 |
+
--adam-betas '(0.9, 0.98)' \
|
77 |
+
--warmup-init-lr 1e-07 \
|
78 |
+
--lr 0.0005 \
|
79 |
+
--warmup-updates 4000 \
|
80 |
+
--dropout 0.2 \
|
81 |
+
--weight-decay 0.0 \
|
82 |
+
--tpu \
|
83 |
+
--distributed-world-size 8 \
|
84 |
+
--max-tokens 8192 \
|
85 |
+
--num-batch-buckets 8 \
|
86 |
+
--tensorboard-logdir {expdir}/exp2_m2o_baseline/tensorboard \
|
87 |
+
--save-dir {expdir}/exp2_m2o_baseline/model \
|
88 |
+
--keep-last-epochs 5 \
|
89 |
+
--patience 5
|
90 |
+
```
|
91 |
+
|
92 |
+
**Note** While training, we noticed that the training was slower on tpus, compared to using multiple GPUs, we have documented some issues and [filed an issue](https://github.com/pytorch/fairseq/issues/3317) at fairseq repo for advice. We'll update this section as we learn more about efficient training on TPUs. Also feel free to open an issue/pull request if you find a bug or know an efficient method to make code train faster on tpus.
|
legacy/translate.sh
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
echo `date`
|
3 |
+
infname=$1
|
4 |
+
outfname=$2
|
5 |
+
src_lang=$3
|
6 |
+
tgt_lang=$4
|
7 |
+
exp_dir=$5
|
8 |
+
ref_fname=$6
|
9 |
+
|
10 |
+
if [ $src_lang == 'en' ]; then
|
11 |
+
SRC_PREFIX='TGT'
|
12 |
+
TGT_PREFIX='SRC'
|
13 |
+
else
|
14 |
+
SRC_PREFIX='SRC'
|
15 |
+
TGT_PREFIX='TGT'
|
16 |
+
fi
|
17 |
+
|
18 |
+
#`dirname $0`/env.sh
|
19 |
+
SUBWORD_NMT_DIR='subword-nmt'
|
20 |
+
model_dir=$exp_dir/model
|
21 |
+
data_bin_dir=$exp_dir/final_bin
|
22 |
+
|
23 |
+
### normalization and script conversion
|
24 |
+
|
25 |
+
echo "Applying normalization and script conversion"
|
26 |
+
input_size=`python preprocess_translate.py $infname $outfname.norm $src_lang`
|
27 |
+
echo "Number of sentences in input: $input_size"
|
28 |
+
|
29 |
+
### apply BPE to input file
|
30 |
+
|
31 |
+
echo "Applying BPE"
|
32 |
+
python $SUBWORD_NMT_DIR/subword_nmt/apply_bpe.py \
|
33 |
+
-c $exp_dir/vocab/bpe_codes.32k.${SRC_PREFIX}_${TGT_PREFIX} \
|
34 |
+
--vocabulary $exp_dir/vocab/vocab.$SRC_PREFIX \
|
35 |
+
--vocabulary-threshold 5 \
|
36 |
+
< $outfname.norm \
|
37 |
+
> $outfname.bpe
|
38 |
+
|
39 |
+
# not needed for joint training
|
40 |
+
# echo "Adding language tags"
|
41 |
+
# python add_tags_translate.py $outfname._bpe $outfname.bpe $src_lang $tgt_lang
|
42 |
+
|
43 |
+
### run decoder
|
44 |
+
|
45 |
+
echo "Decoding"
|
46 |
+
|
47 |
+
src_input_bpe_fname=$outfname.bpe
|
48 |
+
tgt_output_fname=$outfname
|
49 |
+
fairseq-interactive $data_bin_dir \
|
50 |
+
-s $SRC_PREFIX -t $TGT_PREFIX \
|
51 |
+
--distributed-world-size 1 \
|
52 |
+
--path $model_dir/checkpoint_best.pt \
|
53 |
+
--batch-size 64 --buffer-size 2500 --beam 5 --remove-bpe \
|
54 |
+
--skip-invalid-size-inputs-valid-test \
|
55 |
+
--input $src_input_bpe_fname > $tgt_output_fname.log 2>&1
|
56 |
+
|
57 |
+
|
58 |
+
echo "Extracting translations, script conversion and detokenization"
|
59 |
+
python postprocess_translate.py $tgt_output_fname.log $tgt_output_fname $input_size $tgt_lang
|
60 |
+
if [ $src_lang == 'en' ]; then
|
61 |
+
# indicnlp tokenize the output files before evaluation
|
62 |
+
input_size=`python preprocess_translate.py $ref_fname $ref_fname.tok $tgt_lang`
|
63 |
+
input_size=`python preprocess_translate.py $tgt_output_fname $tgt_output_fname.tok $tgt_lang`
|
64 |
+
sacrebleu --tokenize none $ref_fname.tok < $tgt_output_fname.tok
|
65 |
+
else
|
66 |
+
# indic to en models
|
67 |
+
sacrebleu $ref_fname < $tgt_output_fname
|
68 |
+
fi
|
69 |
+
echo `date`
|
70 |
+
echo "Translation completed"
|
model_configs/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import custom_transformer
|
model_configs/custom_transformer.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fairseq.models import register_model_architecture
|
2 |
+
from fairseq.models.transformer import base_architecture
|
3 |
+
|
4 |
+
|
5 |
+
@register_model_architecture("transformer", "transformer_2x")
|
6 |
+
def transformer_big(args):
|
7 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
8 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
9 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
10 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
11 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
|
12 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
|
13 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
14 |
+
base_architecture(args)
|
15 |
+
|
16 |
+
|
17 |
+
@register_model_architecture("transformer", "transformer_4x")
|
18 |
+
def transformer_huge(args):
|
19 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1536)
|
20 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
21 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
22 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
23 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536)
|
24 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
|
25 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
26 |
+
base_architecture(args)
|
27 |
+
|
28 |
+
|
29 |
+
@register_model_architecture("transformer", "transformer_9x")
|
30 |
+
def transformer_xlarge(args):
|
31 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 2048)
|
32 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 8192)
|
33 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
34 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
35 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
|
36 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
|
37 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
38 |
+
base_architecture(args)
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
sacremoses
|
2 |
+
pandas
|
3 |
+
mock
|
4 |
+
sacrebleu
|
5 |
+
pyarrow
|
6 |
+
indic-nlp-library
|
7 |
+
mosestokenizer
|
8 |
+
subword-nmt
|
9 |
+
numpy
|
10 |
+
tensorboardX
|
11 |
+
git+https://github.com/pytorch/fairseq.git
|
scripts/__init__.py
ADDED
File without changes
|
scripts/add_joint_tags_translate.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from tqdm import tqdm
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def add_token(sent, tag_infos):
|
7 |
+
""" add special tokens specified by tag_infos to each element in list
|
8 |
+
|
9 |
+
tag_infos: list of tuples (tag_type,tag)
|
10 |
+
|
11 |
+
each tag_info results in a token of the form: __{tag_type}__{tag}__
|
12 |
+
|
13 |
+
"""
|
14 |
+
|
15 |
+
tokens = []
|
16 |
+
for tag_type, tag in tag_infos:
|
17 |
+
token = '__' + tag_type + '__' + tag + '__'
|
18 |
+
tokens.append(token)
|
19 |
+
|
20 |
+
return ' '.join(tokens) + ' ' + sent
|
21 |
+
|
22 |
+
|
23 |
+
def generate_lang_tag_iterator(infname):
|
24 |
+
with open(infname, 'r', encoding='utf-8') as infile:
|
25 |
+
for line in infile:
|
26 |
+
src, tgt, count = line.strip().split('\t')
|
27 |
+
count = int(count)
|
28 |
+
for _ in range(count):
|
29 |
+
yield (src, tgt)
|
30 |
+
|
31 |
+
|
32 |
+
if __name__ == '__main__':
|
33 |
+
|
34 |
+
expdir = sys.argv[1]
|
35 |
+
dset = sys.argv[2]
|
36 |
+
|
37 |
+
src_fname = '{expdir}/bpe/{dset}.SRC'.format(
|
38 |
+
expdir=expdir, dset=dset)
|
39 |
+
tgt_fname = '{expdir}/bpe/{dset}.TGT'.format(
|
40 |
+
expdir=expdir, dset=dset)
|
41 |
+
meta_fname = '{expdir}/data/{dset}_lang_pairs.txt'.format(
|
42 |
+
expdir=expdir, dset=dset)
|
43 |
+
|
44 |
+
out_src_fname = '{expdir}/final/{dset}.SRC'.format(
|
45 |
+
expdir=expdir, dset=dset)
|
46 |
+
out_tgt_fname = '{expdir}/final/{dset}.TGT'.format(
|
47 |
+
expdir=expdir, dset=dset)
|
48 |
+
lang_tag_iterator = generate_lang_tag_iterator(meta_fname)
|
49 |
+
|
50 |
+
os.makedirs('{expdir}/final'.format(expdir=expdir), exist_ok=True)
|
51 |
+
|
52 |
+
with open(src_fname, 'r', encoding='utf-8') as srcfile, \
|
53 |
+
open(tgt_fname, 'r', encoding='utf-8') as tgtfile, \
|
54 |
+
open(out_src_fname, 'w', encoding='utf-8') as outsrcfile, \
|
55 |
+
open(out_tgt_fname, 'w', encoding='utf-8') as outtgtfile:
|
56 |
+
|
57 |
+
for (l1, l2), src_sent, tgt_sent in tqdm(zip(lang_tag_iterator,
|
58 |
+
srcfile, tgtfile)):
|
59 |
+
outsrcfile.write(add_token(src_sent.strip(), [
|
60 |
+
('src', l1), ('tgt', l2)]) + '\n')
|
61 |
+
outtgtfile.write(tgt_sent.strip() + '\n')
|
scripts/add_tags_translate.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
|
4 |
+
def add_token(sent, tag_infos):
|
5 |
+
""" add special tokens specified by tag_infos to each element in list
|
6 |
+
|
7 |
+
tag_infos: list of tuples (tag_type,tag)
|
8 |
+
|
9 |
+
each tag_info results in a token of the form: __{tag_type}__{tag}__
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
tokens = []
|
14 |
+
for tag_type, tag in tag_infos:
|
15 |
+
token = '__' + tag_type + '__' + tag + '__'
|
16 |
+
tokens.append(token)
|
17 |
+
|
18 |
+
return ' '.join(tokens) + ' ' + sent
|
19 |
+
|
20 |
+
|
21 |
+
if __name__ == '__main__':
|
22 |
+
|
23 |
+
infname = sys.argv[1]
|
24 |
+
outfname = sys.argv[2]
|
25 |
+
src_lang = sys.argv[3]
|
26 |
+
tgt_lang = sys.argv[4]
|
27 |
+
|
28 |
+
with open(infname, 'r', encoding='utf-8') as infile, \
|
29 |
+
open(outfname, 'w', encoding='utf-8') as outfile:
|
30 |
+
for line in infile:
|
31 |
+
outstr = add_token(
|
32 |
+
line.strip(), [('src', src_lang), ('tgt', tgt_lang)])
|
33 |
+
outfile.write(outstr + '\n')
|
scripts/clean_vocab.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import codecs
|
3 |
+
|
4 |
+
def clean_vocab(in_vocab_fname, out_vocab_fname):
|
5 |
+
with codecs.open(in_vocab_fname, "r", encoding="utf-8") as infile, codecs.open(
|
6 |
+
out_vocab_fname, "w", encoding="utf-8"
|
7 |
+
) as outfile:
|
8 |
+
for i, line in enumerate(infile):
|
9 |
+
fields = line.strip("\r\n ").split(" ")
|
10 |
+
if len(fields) == 2:
|
11 |
+
outfile.write(line)
|
12 |
+
if len(fields) != 2:
|
13 |
+
print("{}: {}".format(i, line.strip()))
|
14 |
+
for c in line:
|
15 |
+
print("{}:{}".format(c, hex(ord(c))))
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == "__main__":
|
19 |
+
clean_vocab(sys.argv[1], sys.argv[2])
|
scripts/concat_joint_data.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from tqdm import tqdm
|
3 |
+
import sys
|
4 |
+
|
5 |
+
LANGS = [
|
6 |
+
"as",
|
7 |
+
"bn",
|
8 |
+
"gu",
|
9 |
+
"hi",
|
10 |
+
"kn",
|
11 |
+
"ml",
|
12 |
+
"mr",
|
13 |
+
"or",
|
14 |
+
"pa",
|
15 |
+
"ta",
|
16 |
+
"te",
|
17 |
+
#"ur"
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def add_token(sent, tag_infos):
|
22 |
+
""" add special tokens specified by tag_infos to each element in list
|
23 |
+
|
24 |
+
tag_infos: list of tuples (tag_type,tag)
|
25 |
+
|
26 |
+
each tag_info results in a token of the form: __{tag_type}__{tag}__
|
27 |
+
|
28 |
+
"""
|
29 |
+
|
30 |
+
tokens = []
|
31 |
+
for tag_type, tag in tag_infos:
|
32 |
+
token = '__' + tag_type + '__' + tag + '__'
|
33 |
+
tokens.append(token)
|
34 |
+
|
35 |
+
return ' '.join(tokens) + ' ' + sent
|
36 |
+
|
37 |
+
|
38 |
+
def concat_data(data_dir, outdir, lang_pair_list,
|
39 |
+
out_src_lang='SRC', out_trg_lang='TGT', split='train'):
|
40 |
+
"""
|
41 |
+
data_dir: input dir, contains directories for language pairs named l1-l2
|
42 |
+
"""
|
43 |
+
os.makedirs(outdir, exist_ok=True)
|
44 |
+
|
45 |
+
out_src_fname = '{}/{}.{}'.format(outdir, split, out_src_lang)
|
46 |
+
out_trg_fname = '{}/{}.{}'.format(outdir, split, out_trg_lang)
|
47 |
+
# out_meta_fname='{}/metadata.txt'.format(outdir)
|
48 |
+
|
49 |
+
print()
|
50 |
+
print(out_src_fname)
|
51 |
+
print(out_trg_fname)
|
52 |
+
# print(out_meta_fname)
|
53 |
+
|
54 |
+
# concatenate train data
|
55 |
+
if os.path.isfile(out_src_fname):
|
56 |
+
os.unlink(out_src_fname)
|
57 |
+
if os.path.isfile(out_trg_fname):
|
58 |
+
os.unlink(out_trg_fname)
|
59 |
+
# if os.path.isfile(out_meta_fname):
|
60 |
+
# os.unlink(out_meta_fname)
|
61 |
+
|
62 |
+
for src_lang, trg_lang in tqdm(lang_pair_list):
|
63 |
+
print('src: {}, tgt:{}'.format(src_lang, trg_lang))
|
64 |
+
|
65 |
+
in_src_fname = '{}/{}-{}/{}.{}'.format(
|
66 |
+
data_dir, src_lang, trg_lang, split, src_lang)
|
67 |
+
in_trg_fname = '{}/{}-{}/{}.{}'.format(
|
68 |
+
data_dir, src_lang, trg_lang, split, trg_lang)
|
69 |
+
|
70 |
+
if not os.path.exists(in_src_fname):
|
71 |
+
continue
|
72 |
+
if not os.path.exists(in_trg_fname):
|
73 |
+
continue
|
74 |
+
|
75 |
+
print(in_src_fname)
|
76 |
+
os.system('cat {} >> {}'.format(in_src_fname, out_src_fname))
|
77 |
+
|
78 |
+
print(in_trg_fname)
|
79 |
+
os.system('cat {} >> {}'.format(in_trg_fname, out_trg_fname))
|
80 |
+
|
81 |
+
|
82 |
+
# with open('{}/lang_pairs.txt'.format(outdir),'w',encoding='utf-8') as lpfile:
|
83 |
+
# lpfile.write('\n'.join( [ '-'.join(x) for x in lang_pair_list ] ))
|
84 |
+
|
85 |
+
corpus_stats(data_dir, outdir, lang_pair_list, split)
|
86 |
+
|
87 |
+
|
88 |
+
def corpus_stats(data_dir, outdir, lang_pair_list, split):
|
89 |
+
"""
|
90 |
+
data_dir: input dir, contains directories for language pairs named l1-l2
|
91 |
+
"""
|
92 |
+
|
93 |
+
with open('{}/{}_lang_pairs.txt'.format(outdir, split), 'w', encoding='utf-8') as lpfile:
|
94 |
+
|
95 |
+
for src_lang, trg_lang in tqdm(lang_pair_list):
|
96 |
+
print('src: {}, tgt:{}'.format(src_lang, trg_lang))
|
97 |
+
|
98 |
+
in_src_fname = '{}/{}-{}/{}.{}'.format(
|
99 |
+
data_dir, src_lang, trg_lang, split, src_lang)
|
100 |
+
# in_trg_fname='{}/{}-{}/train.{}'.format(data_dir,src_lang,trg_lang,trg_lang)
|
101 |
+
if not os.path.exists(in_src_fname):
|
102 |
+
continue
|
103 |
+
|
104 |
+
print(in_src_fname)
|
105 |
+
corpus_size = 0
|
106 |
+
with open(in_src_fname, 'r', encoding='utf-8') as infile:
|
107 |
+
corpus_size = sum(map(lambda x: 1, infile))
|
108 |
+
|
109 |
+
lpfile.write('{}\t{}\t{}\n'.format(
|
110 |
+
src_lang, trg_lang, corpus_size))
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == '__main__':
|
114 |
+
|
115 |
+
in_dir = sys.argv[1]
|
116 |
+
out_dir = sys.argv[2]
|
117 |
+
src_lang = sys.argv[3]
|
118 |
+
tgt_lang = sys.argv[4]
|
119 |
+
split = sys.argv[5]
|
120 |
+
lang_pair_list = []
|
121 |
+
|
122 |
+
if src_lang == 'en':
|
123 |
+
for lang in LANGS:
|
124 |
+
lang_pair_list.append(['en', lang])
|
125 |
+
else:
|
126 |
+
for lang in LANGS:
|
127 |
+
lang_pair_list.append([lang, 'en'])
|
128 |
+
|
129 |
+
concat_data(in_dir, out_dir, lang_pair_list, split=split)
|
130 |
+
|
scripts/extract_non_english_pairs.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
import os
|
3 |
+
from collections import defaultdict
|
4 |
+
|
5 |
+
|
6 |
+
def read_file(fname):
|
7 |
+
with open(fname, "r", encoding="utf-8") as infile:
|
8 |
+
for line in infile:
|
9 |
+
yield line.strip()
|
10 |
+
|
11 |
+
|
12 |
+
def extract_non_english_pairs(indir, outdir, LANGS):
|
13 |
+
"""
|
14 |
+
Extracts non-english pair parallel corpora
|
15 |
+
|
16 |
+
indir: contains english centric data in the following form:
|
17 |
+
- directory named en-xx for language xx
|
18 |
+
- each directory contains a train.en and train.xx
|
19 |
+
outdir: output directory to store mined data for each pair.
|
20 |
+
One directory is created for each pair.
|
21 |
+
LANGS: list of languages in the corpus (other than English).
|
22 |
+
The language codes must correspond to the ones used in the
|
23 |
+
files and directories in indir. Prefarably, sort the languages
|
24 |
+
in this list in alphabetic order. outdir will contain data for xx-yy,
|
25 |
+
but not for yy-xx, so it will be convenient to have this list in sorted order.
|
26 |
+
"""
|
27 |
+
|
28 |
+
for i in tqdm(range(len(LANGS) - 1)):
|
29 |
+
print()
|
30 |
+
for j in range(i + 1, len(LANGS)):
|
31 |
+
lang1 = LANGS[i]
|
32 |
+
lang2 = LANGS[j]
|
33 |
+
# print()
|
34 |
+
print("{} {}".format(lang1, lang2))
|
35 |
+
|
36 |
+
fname1 = "{}/en-{}/train.en".format(indir, lang1)
|
37 |
+
fname2 = "{}/en-{}/train.en".format(indir, lang2)
|
38 |
+
# print(fname1)
|
39 |
+
# print(fname2)
|
40 |
+
enset_l1 = set(read_file(fname1))
|
41 |
+
common_en_set = enset_l1.intersection(read_file(fname2))
|
42 |
+
|
43 |
+
## this block should be used if you want to consider multiple translations.
|
44 |
+
# il_fname1 = "{}/en-{}/train.{}".format(indir, lang1, lang1)
|
45 |
+
# en_lang1_dict = defaultdict(list)
|
46 |
+
# for en_line, il_line in zip(read_file(fname1), read_file(il_fname1)):
|
47 |
+
# if en_line in common_en_set:
|
48 |
+
# en_lang1_dict[en_line].append(il_line)
|
49 |
+
|
50 |
+
# # this block should be used if you DONT to consider multiple translation.
|
51 |
+
il_fname1='{}/en-{}/train.{}'.format(indir,lang1,lang1)
|
52 |
+
en_lang1_dict={}
|
53 |
+
for en_line,il_line in zip(read_file(fname1),read_file(il_fname1)):
|
54 |
+
if en_line in common_en_set:
|
55 |
+
en_lang1_dict[en_line]=il_line
|
56 |
+
|
57 |
+
os.makedirs("{}/{}-{}".format(outdir, lang1, lang2), exist_ok=True)
|
58 |
+
out_l1_fname = "{o}/{l1}-{l2}/train.{l1}".format(
|
59 |
+
o=outdir, l1=lang1, l2=lang2
|
60 |
+
)
|
61 |
+
out_l2_fname = "{o}/{l1}-{l2}/train.{l2}".format(
|
62 |
+
o=outdir, l1=lang1, l2=lang2
|
63 |
+
)
|
64 |
+
|
65 |
+
il_fname2 = "{}/en-{}/train.{}".format(indir, lang2, lang2)
|
66 |
+
with open(out_l1_fname, "w", encoding="utf-8") as out_l1_file, open(
|
67 |
+
out_l2_fname, "w", encoding="utf-8"
|
68 |
+
) as out_l2_file:
|
69 |
+
for en_line, il_line in zip(read_file(fname2), read_file(il_fname2)):
|
70 |
+
if en_line in en_lang1_dict:
|
71 |
+
|
72 |
+
# this block should be used if you want to consider multiple tranlations.
|
73 |
+
for il_line_lang1 in en_lang1_dict[en_line]:
|
74 |
+
# lang1_line, lang2_line = il_line_lang1, il_line
|
75 |
+
# out_l1_file.write(lang1_line + "\n")
|
76 |
+
# out_l2_file.write(lang2_line + "\n")
|
77 |
+
|
78 |
+
# this block should be used if you DONT to consider multiple translation.
|
79 |
+
lang1_line, lang2_line = en_lang1_dict[en_line], il_line
|
80 |
+
out_l1_file.write(lang1_line+'\n')
|
81 |
+
out_l2_file.write(lang2_line+'\n')
|
82 |
+
|
83 |
+
|
84 |
+
def get_extracted_stats(outdir, LANGS):
|
85 |
+
"""
|
86 |
+
gathers stats from the extracted directories
|
87 |
+
|
88 |
+
outdir: output directory to store mined data for each pair.
|
89 |
+
One directory is created for each pair.
|
90 |
+
LANGS: list of languages in the corpus (other than languages).
|
91 |
+
The language codes must correspond to the ones used in the
|
92 |
+
files and directories in indir. Prefarably, sort the languages
|
93 |
+
in this list in alphabetic order. outdir will contain data for xx-yy,
|
94 |
+
"""
|
95 |
+
common_stats = []
|
96 |
+
for i in tqdm(range(len(LANGS) - 1)):
|
97 |
+
for j in range(i + 1, len(LANGS)):
|
98 |
+
lang1 = LANGS[i]
|
99 |
+
lang2 = LANGS[j]
|
100 |
+
|
101 |
+
out_l1_fname = "{o}/{l1}-{l2}/train.{l1}".format(
|
102 |
+
o=outdir, l1=lang1, l2=lang2
|
103 |
+
)
|
104 |
+
|
105 |
+
cnt = sum([1 for _ in read_file(out_l1_fname)])
|
106 |
+
common_stats.append((lang1, lang2, cnt))
|
107 |
+
common_stats.append((lang2, lang1, cnt))
|
108 |
+
return common_stats
|
scripts/postprocess_score.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
def postprocess(
|
4 |
+
infname, outfname, input_size
|
5 |
+
):
|
6 |
+
"""
|
7 |
+
parse fairseq interactive output, convert script back to native Indic script (in case of Indic languages) and detokenize.
|
8 |
+
|
9 |
+
infname: fairseq log file
|
10 |
+
outfname: output file of translation (sentences not translated contain the dummy string 'DUMMY_OUTPUT'
|
11 |
+
input_size: expected number of output sentences
|
12 |
+
"""
|
13 |
+
|
14 |
+
consolidated_testoutput = []
|
15 |
+
# with open(infname,'r',encoding='utf-8') as infile:
|
16 |
+
# consolidated_testoutput= list(map(lambda x: x.strip(), filter(lambda x: x.startswith('H-'),infile) ))
|
17 |
+
# consolidated_testoutput.sort(key=lambda x: int(x.split('\t')[0].split('-')[1]))
|
18 |
+
# consolidated_testoutput=[ x.split('\t')[2] for x in consolidated_testoutput ]
|
19 |
+
|
20 |
+
consolidated_testoutput = [(x, 0.0, "") for x in range(input_size)]
|
21 |
+
temp_testoutput = []
|
22 |
+
with open(infname, "r", encoding="utf-8") as infile:
|
23 |
+
temp_testoutput = list(
|
24 |
+
map(
|
25 |
+
lambda x: x.strip().split("\t"),
|
26 |
+
filter(lambda x: x.startswith("H-"), infile),
|
27 |
+
)
|
28 |
+
)
|
29 |
+
temp_testoutput = list(
|
30 |
+
map(lambda x: (int(x[0].split("-")[1]), float(x[1]), x[2]), temp_testoutput)
|
31 |
+
)
|
32 |
+
for sid, score, hyp in temp_testoutput:
|
33 |
+
consolidated_testoutput[sid] = (sid, score, hyp)
|
34 |
+
#consolidated_testoutput = [x[2] for x in consolidated_testoutput]
|
35 |
+
|
36 |
+
with open(outfname, "w", encoding="utf-8") as outfile:
|
37 |
+
for (sid, score, hyp) in consolidated_testoutput:
|
38 |
+
outfile.write("{}\n".format(score))
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
|
42 |
+
infname = sys.argv[1]
|
43 |
+
outfname = sys.argv[2]
|
44 |
+
input_size = int(sys.argv[3])
|
45 |
+
|
46 |
+
postprocess(
|
47 |
+
infname, outfname, input_size
|
48 |
+
)
|
scripts/postprocess_translate.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
INDIC_NLP_LIB_HOME = "indic_nlp_library"
|
2 |
+
INDIC_NLP_RESOURCES = "indic_nlp_resources"
|
3 |
+
import sys
|
4 |
+
|
5 |
+
from indicnlp import transliterate
|
6 |
+
|
7 |
+
sys.path.append(r"{}".format(INDIC_NLP_LIB_HOME))
|
8 |
+
from indicnlp import common
|
9 |
+
|
10 |
+
common.set_resources_path(INDIC_NLP_RESOURCES)
|
11 |
+
from indicnlp import loader
|
12 |
+
|
13 |
+
loader.load()
|
14 |
+
from sacremoses import MosesPunctNormalizer
|
15 |
+
from sacremoses import MosesTokenizer
|
16 |
+
from sacremoses import MosesDetokenizer
|
17 |
+
from collections import defaultdict
|
18 |
+
|
19 |
+
import indicnlp
|
20 |
+
from indicnlp.tokenize import indic_tokenize
|
21 |
+
from indicnlp.tokenize import indic_detokenize
|
22 |
+
from indicnlp.normalize import indic_normalize
|
23 |
+
from indicnlp.transliterate import unicode_transliterate
|
24 |
+
|
25 |
+
|
26 |
+
def postprocess(
|
27 |
+
infname, outfname, input_size, lang, common_lang="hi", transliterate=False
|
28 |
+
):
|
29 |
+
"""
|
30 |
+
parse fairseq interactive output, convert script back to native Indic script (in case of Indic languages) and detokenize.
|
31 |
+
|
32 |
+
infname: fairseq log file
|
33 |
+
outfname: output file of translation (sentences not translated contain the dummy string 'DUMMY_OUTPUT'
|
34 |
+
input_size: expected number of output sentences
|
35 |
+
lang: language
|
36 |
+
"""
|
37 |
+
|
38 |
+
consolidated_testoutput = []
|
39 |
+
# with open(infname,'r',encoding='utf-8') as infile:
|
40 |
+
# consolidated_testoutput= list(map(lambda x: x.strip(), filter(lambda x: x.startswith('H-'),infile) ))
|
41 |
+
# consolidated_testoutput.sort(key=lambda x: int(x.split('\t')[0].split('-')[1]))
|
42 |
+
# consolidated_testoutput=[ x.split('\t')[2] for x in consolidated_testoutput ]
|
43 |
+
|
44 |
+
consolidated_testoutput = [(x, 0.0, "") for x in range(input_size)]
|
45 |
+
temp_testoutput = []
|
46 |
+
with open(infname, "r", encoding="utf-8") as infile:
|
47 |
+
temp_testoutput = list(
|
48 |
+
map(
|
49 |
+
lambda x: x.strip().split("\t"),
|
50 |
+
filter(lambda x: x.startswith("H-"), infile),
|
51 |
+
)
|
52 |
+
)
|
53 |
+
temp_testoutput = list(
|
54 |
+
map(lambda x: (int(x[0].split("-")[1]), float(x[1]), x[2]), temp_testoutput)
|
55 |
+
)
|
56 |
+
for sid, score, hyp in temp_testoutput:
|
57 |
+
consolidated_testoutput[sid] = (sid, score, hyp)
|
58 |
+
consolidated_testoutput = [x[2] for x in consolidated_testoutput]
|
59 |
+
|
60 |
+
if lang == "en":
|
61 |
+
en_detok = MosesDetokenizer(lang="en")
|
62 |
+
with open(outfname, "w", encoding="utf-8") as outfile:
|
63 |
+
for sent in consolidated_testoutput:
|
64 |
+
outfile.write(en_detok.detokenize(sent.split(" ")) + "\n")
|
65 |
+
else:
|
66 |
+
xliterator = unicode_transliterate.UnicodeIndicTransliterator()
|
67 |
+
with open(outfname, "w", encoding="utf-8") as outfile:
|
68 |
+
for sent in consolidated_testoutput:
|
69 |
+
if transliterate:
|
70 |
+
outstr = indic_detokenize.trivial_detokenize(
|
71 |
+
xliterator.transliterate(sent, common_lang, lang), lang
|
72 |
+
)
|
73 |
+
else:
|
74 |
+
outstr = indic_detokenize.trivial_detokenize(sent, lang)
|
75 |
+
outfile.write(outstr + "\n")
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
# # The path to the local git repo for Indic NLP library
|
80 |
+
# INDIC_NLP_LIB_HOME="indic_nlp_library"
|
81 |
+
# INDIC_NLP_RESOURCES = "indic_nlp_resources"
|
82 |
+
# sys.path.append('{}'.format(INDIC_NLP_LIB_HOME))
|
83 |
+
# common.set_resources_path(INDIC_NLP_RESOURCES)
|
84 |
+
# # The path to the local git repo for Indic NLP Resources
|
85 |
+
# INDIC_NLP_RESOURCES=""
|
86 |
+
|
87 |
+
# sys.path.append('{}'.format(INDIC_NLP_LIB_HOME))
|
88 |
+
# common.set_resources_path(INDIC_NLP_RESOURCES)
|
89 |
+
|
90 |
+
# loader.load()
|
91 |
+
|
92 |
+
infname = sys.argv[1]
|
93 |
+
outfname = sys.argv[2]
|
94 |
+
input_size = int(sys.argv[3])
|
95 |
+
lang = sys.argv[4]
|
96 |
+
if len(sys.argv) == 5:
|
97 |
+
transliterate = False
|
98 |
+
elif len(sys.argv) == 6:
|
99 |
+
transliterate = sys.argv[5]
|
100 |
+
if transliterate.lower() == "true":
|
101 |
+
transliterate = True
|
102 |
+
else:
|
103 |
+
transliterate = False
|
104 |
+
else:
|
105 |
+
print(f"Invalid arguments: {sys.argv}")
|
106 |
+
exit()
|
107 |
+
|
108 |
+
postprocess(
|
109 |
+
infname, outfname, input_size, lang, common_lang="hi", transliterate=transliterate
|
110 |
+
)
|
scripts/preprocess_translate.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
INDIC_NLP_LIB_HOME = "indic_nlp_library"
|
2 |
+
INDIC_NLP_RESOURCES = "indic_nlp_resources"
|
3 |
+
import sys
|
4 |
+
|
5 |
+
sys.path.append(r"{}".format(INDIC_NLP_LIB_HOME))
|
6 |
+
from indicnlp import common
|
7 |
+
|
8 |
+
common.set_resources_path(INDIC_NLP_RESOURCES)
|
9 |
+
from indicnlp import loader
|
10 |
+
|
11 |
+
loader.load()
|
12 |
+
from sacremoses import MosesPunctNormalizer
|
13 |
+
from sacremoses import MosesTokenizer
|
14 |
+
from sacremoses import MosesDetokenizer
|
15 |
+
from collections import defaultdict
|
16 |
+
|
17 |
+
from tqdm import tqdm
|
18 |
+
from joblib import Parallel, delayed
|
19 |
+
|
20 |
+
from indicnlp.tokenize import indic_tokenize
|
21 |
+
from indicnlp.tokenize import indic_detokenize
|
22 |
+
from indicnlp.normalize import indic_normalize
|
23 |
+
from indicnlp.transliterate import unicode_transliterate
|
24 |
+
|
25 |
+
|
26 |
+
en_tok = MosesTokenizer(lang="en")
|
27 |
+
en_normalizer = MosesPunctNormalizer()
|
28 |
+
|
29 |
+
|
30 |
+
def preprocess_line(line, normalizer, lang, transliterate=False):
|
31 |
+
if lang == "en":
|
32 |
+
return " ".join(
|
33 |
+
en_tok.tokenize(en_normalizer.normalize(line.strip()), escape=False)
|
34 |
+
)
|
35 |
+
elif transliterate:
|
36 |
+
# line = indic_detokenize.trivial_detokenize(line.strip(), lang)
|
37 |
+
return unicode_transliterate.UnicodeIndicTransliterator.transliterate(
|
38 |
+
" ".join(
|
39 |
+
indic_tokenize.trivial_tokenize(
|
40 |
+
normalizer.normalize(line.strip()), lang
|
41 |
+
)
|
42 |
+
),
|
43 |
+
lang,
|
44 |
+
"hi",
|
45 |
+
).replace(" ् ", "्")
|
46 |
+
else:
|
47 |
+
# we only need to transliterate for joint training
|
48 |
+
return " ".join(
|
49 |
+
indic_tokenize.trivial_tokenize(normalizer.normalize(line.strip()), lang)
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
def preprocess(infname, outfname, lang, transliterate=False):
|
54 |
+
"""
|
55 |
+
Normalize, tokenize and script convert(for Indic)
|
56 |
+
return number of sentences input file
|
57 |
+
|
58 |
+
"""
|
59 |
+
|
60 |
+
n = 0
|
61 |
+
num_lines = sum(1 for line in open(infname, "r"))
|
62 |
+
if lang == "en":
|
63 |
+
with open(infname, "r", encoding="utf-8") as infile, open(
|
64 |
+
outfname, "w", encoding="utf-8"
|
65 |
+
) as outfile:
|
66 |
+
|
67 |
+
out_lines = Parallel(n_jobs=-1, backend="multiprocessing")(
|
68 |
+
delayed(preprocess_line)(line, None, lang)
|
69 |
+
for line in tqdm(infile, total=num_lines)
|
70 |
+
)
|
71 |
+
|
72 |
+
for line in out_lines:
|
73 |
+
outfile.write(line + "\n")
|
74 |
+
n += 1
|
75 |
+
|
76 |
+
else:
|
77 |
+
normfactory = indic_normalize.IndicNormalizerFactory()
|
78 |
+
normalizer = normfactory.get_normalizer(lang)
|
79 |
+
# reading
|
80 |
+
with open(infname, "r", encoding="utf-8") as infile, open(
|
81 |
+
outfname, "w", encoding="utf-8"
|
82 |
+
) as outfile:
|
83 |
+
|
84 |
+
out_lines = Parallel(n_jobs=-1, backend="multiprocessing")(
|
85 |
+
delayed(preprocess_line)(line, normalizer, lang, transliterate)
|
86 |
+
for line in tqdm(infile, total=num_lines)
|
87 |
+
)
|
88 |
+
|
89 |
+
for line in out_lines:
|
90 |
+
outfile.write(line + "\n")
|
91 |
+
n += 1
|
92 |
+
return n
|
93 |
+
|
94 |
+
|
95 |
+
def old_preprocess(infname, outfname, lang):
|
96 |
+
"""
|
97 |
+
Preparing each corpus file:
|
98 |
+
- Normalization
|
99 |
+
- Tokenization
|
100 |
+
- Script coversion to Devanagari for Indic scripts
|
101 |
+
"""
|
102 |
+
n = 0
|
103 |
+
num_lines = sum(1 for line in open(infname, "r"))
|
104 |
+
# reading
|
105 |
+
with open(infname, "r", encoding="utf-8") as infile, open(
|
106 |
+
outfname, "w", encoding="utf-8"
|
107 |
+
) as outfile:
|
108 |
+
|
109 |
+
if lang == "en":
|
110 |
+
en_tok = MosesTokenizer(lang="en")
|
111 |
+
en_normalizer = MosesPunctNormalizer()
|
112 |
+
for line in tqdm(infile, total=num_lines):
|
113 |
+
outline = " ".join(
|
114 |
+
en_tok.tokenize(en_normalizer.normalize(line.strip()), escape=False)
|
115 |
+
)
|
116 |
+
outfile.write(outline + "\n")
|
117 |
+
n += 1
|
118 |
+
|
119 |
+
else:
|
120 |
+
normfactory = indic_normalize.IndicNormalizerFactory()
|
121 |
+
normalizer = normfactory.get_normalizer(lang)
|
122 |
+
for line in tqdm(infile, total=num_lines):
|
123 |
+
outline = (
|
124 |
+
unicode_transliterate.UnicodeIndicTransliterator.transliterate(
|
125 |
+
" ".join(
|
126 |
+
indic_tokenize.trivial_tokenize(
|
127 |
+
normalizer.normalize(line.strip()), lang
|
128 |
+
)
|
129 |
+
),
|
130 |
+
lang,
|
131 |
+
"hi",
|
132 |
+
).replace(" ् ", "्")
|
133 |
+
)
|
134 |
+
|
135 |
+
outfile.write(outline + "\n")
|
136 |
+
n += 1
|
137 |
+
return n
|
138 |
+
|
139 |
+
|
140 |
+
if __name__ == "__main__":
|
141 |
+
|
142 |
+
# INDIC_NLP_LIB_HOME = "indic_nlp_library"
|
143 |
+
# INDIC_NLP_RESOURCES = "indic_nlp_resources"
|
144 |
+
# sys.path.append(r'{}'.format(INDIC_NLP_LIB_HOME))
|
145 |
+
# common.set_resources_path(INDIC_NLP_RESOURCES)
|
146 |
+
|
147 |
+
# data_dir = '../joint_training/v1'
|
148 |
+
# new_dir = data_dir + '.norm'
|
149 |
+
# for path, subdirs, files in os.walk(data_dir):
|
150 |
+
# for name in files:
|
151 |
+
# infile = os.path.join(path, name)
|
152 |
+
# lang = infile.split('.')[-1]
|
153 |
+
# outfile = os.path.join(path.replace(data_dir, new_dir), name)
|
154 |
+
# preprocess(infile, outfile, lang)
|
155 |
+
# loader.load()
|
156 |
+
|
157 |
+
infname = sys.argv[1]
|
158 |
+
outfname = sys.argv[2]
|
159 |
+
lang = sys.argv[3]
|
160 |
+
|
161 |
+
if len(sys.argv) == 4:
|
162 |
+
transliterate = False
|
163 |
+
elif len(sys.argv) == 5:
|
164 |
+
transliterate = sys.argv[4]
|
165 |
+
if transliterate.lower() == "true":
|
166 |
+
transliterate = True
|
167 |
+
else:
|
168 |
+
transliterate = False
|
169 |
+
else:
|
170 |
+
print(f"Invalid arguments: {sys.argv}")
|
171 |
+
exit()
|
172 |
+
print(preprocess(infname, outfname, lang, transliterate))
|
scripts/remove_large_sentences.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
import sys
|
3 |
+
|
4 |
+
|
5 |
+
def remove_large_sentences(src_path, tgt_path):
|
6 |
+
count = 0
|
7 |
+
new_src_lines = []
|
8 |
+
new_tgt_lines = []
|
9 |
+
src_num_lines = sum(1 for line in open(src_path, "r", encoding="utf-8"))
|
10 |
+
tgt_num_lines = sum(1 for line in open(tgt_path, "r", encoding="utf-8"))
|
11 |
+
assert src_num_lines == tgt_num_lines
|
12 |
+
with open(src_path, encoding="utf-8") as f1, open(tgt_path, encoding="utf-8") as f2:
|
13 |
+
for src_line, tgt_line in tqdm(zip(f1, f2), total=src_num_lines):
|
14 |
+
src_tokens = src_line.strip().split(" ")
|
15 |
+
tgt_tokens = tgt_line.strip().split(" ")
|
16 |
+
if len(src_tokens) > 200 or len(tgt_tokens) > 200:
|
17 |
+
count += 1
|
18 |
+
continue
|
19 |
+
new_src_lines.append(src_line)
|
20 |
+
new_tgt_lines.append(tgt_line)
|
21 |
+
return count, new_src_lines, new_tgt_lines
|
22 |
+
|
23 |
+
|
24 |
+
def create_txt(outFile, lines, add_newline=False):
|
25 |
+
outfile = open("{0}".format(outFile), "w", encoding="utf-8")
|
26 |
+
for line in lines:
|
27 |
+
if add_newline:
|
28 |
+
outfile.write(line + "\n")
|
29 |
+
else:
|
30 |
+
outfile.write(line)
|
31 |
+
outfile.close()
|
32 |
+
|
33 |
+
|
34 |
+
if __name__ == "__main__":
|
35 |
+
|
36 |
+
src_path = sys.argv[1]
|
37 |
+
tgt_path = sys.argv[2]
|
38 |
+
new_src_path = sys.argv[3]
|
39 |
+
new_tgt_path = sys.argv[4]
|
40 |
+
|
41 |
+
count, new_src_lines, new_tgt_lines = remove_large_sentences(src_path, tgt_path)
|
42 |
+
print(f'{count} lines removed due to seq_len > 200')
|
43 |
+
create_txt(new_src_path, new_src_lines)
|
44 |
+
create_txt(new_tgt_path, new_tgt_lines)
|
scripts/remove_train_devtest_overlaps.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import string
|
3 |
+
import shutil
|
4 |
+
from itertools import permutations, chain
|
5 |
+
from collections import defaultdict
|
6 |
+
from tqdm import tqdm
|
7 |
+
import sys
|
8 |
+
|
9 |
+
INDIC_LANGS = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
|
10 |
+
# we will be testing the overlaps of training data with all these benchmarks
|
11 |
+
# benchmarks = ['wat2021-devtest', 'wat2020-devtest', 'wat-2018', 'wmt-news', 'ufal-ta', 'pmi']
|
12 |
+
|
13 |
+
|
14 |
+
def read_lines(path):
|
15 |
+
# if path doesnt exist, return empty list
|
16 |
+
if not os.path.exists(path):
|
17 |
+
return []
|
18 |
+
with open(path, "r") as f:
|
19 |
+
lines = f.readlines()
|
20 |
+
return lines
|
21 |
+
|
22 |
+
|
23 |
+
def create_txt(outFile, lines):
|
24 |
+
add_newline = not "\n" in lines[0]
|
25 |
+
outfile = open("{0}".format(outFile), "w")
|
26 |
+
for line in lines:
|
27 |
+
if add_newline:
|
28 |
+
outfile.write(line + "\n")
|
29 |
+
else:
|
30 |
+
outfile.write(line)
|
31 |
+
|
32 |
+
outfile.close()
|
33 |
+
|
34 |
+
|
35 |
+
def pair_dedup_files(src_file, tgt_file):
|
36 |
+
src_lines = read_lines(src_file)
|
37 |
+
tgt_lines = read_lines(tgt_file)
|
38 |
+
len_before = len(src_lines)
|
39 |
+
|
40 |
+
src_dedupped, tgt_dedupped = pair_dedup_lists(src_lines, tgt_lines)
|
41 |
+
|
42 |
+
len_after = len(src_dedupped)
|
43 |
+
num_duplicates = len_before - len_after
|
44 |
+
|
45 |
+
print(f"Dropped duplicate pairs in {src_file} Num duplicates -> {num_duplicates}")
|
46 |
+
create_txt(src_file, src_dedupped)
|
47 |
+
create_txt(tgt_file, tgt_dedupped)
|
48 |
+
|
49 |
+
|
50 |
+
def pair_dedup_lists(src_list, tgt_list):
|
51 |
+
src_tgt = list(set(zip(src_list, tgt_list)))
|
52 |
+
src_deduped, tgt_deduped = zip(*src_tgt)
|
53 |
+
return src_deduped, tgt_deduped
|
54 |
+
|
55 |
+
|
56 |
+
def strip_and_normalize(line):
|
57 |
+
# lowercase line, remove spaces and strip punctuation
|
58 |
+
|
59 |
+
# one of the fastest way to add an exclusion list and remove that
|
60 |
+
# list of characters from a string
|
61 |
+
# https://towardsdatascience.com/how-to-efficiently-remove-punctuations-from-a-string-899ad4a059fb
|
62 |
+
exclist = string.punctuation + "\u0964"
|
63 |
+
table_ = str.maketrans("", "", exclist)
|
64 |
+
|
65 |
+
line = line.replace(" ", "").lower()
|
66 |
+
# dont use this method, it is painfully slow
|
67 |
+
# line = "".join([i for i in line if i not in string.punctuation])
|
68 |
+
line = line.translate(table_)
|
69 |
+
return line
|
70 |
+
|
71 |
+
|
72 |
+
def expand_tupled_list(list_of_tuples):
|
73 |
+
# convert list of tuples into two lists
|
74 |
+
# https://stackoverflow.com/questions/8081545/how-to-convert-list-of-tuples-to-multiple-lists
|
75 |
+
# [(en, as), (as, bn), (bn, gu)] - > [en, as, bn], [as, bn, gu]
|
76 |
+
list_a, list_b = map(list, zip(*list_of_tuples))
|
77 |
+
return list_a, list_b
|
78 |
+
|
79 |
+
|
80 |
+
def get_src_tgt_lang_lists(many2many=False):
|
81 |
+
if many2many is False:
|
82 |
+
SRC_LANGS = ["en"]
|
83 |
+
TGT_LANGS = INDIC_LANGS
|
84 |
+
else:
|
85 |
+
all_languages = INDIC_LANGS + ["en"]
|
86 |
+
# lang_pairs = list(permutations(all_languages, 2))
|
87 |
+
|
88 |
+
SRC_LANGS, TGT_LANGS = all_languages, all_languages
|
89 |
+
|
90 |
+
return SRC_LANGS, TGT_LANGS
|
91 |
+
|
92 |
+
|
93 |
+
def normalize_and_gather_all_benchmarks(devtest_dir, many2many=False):
|
94 |
+
|
95 |
+
# This is a dict of dict of lists
|
96 |
+
# the first keys are for lang-pair, the second keys are for src/tgt
|
97 |
+
# the values are the devtest lines.
|
98 |
+
# so devtest_pairs_normalized[en-as][src] will store src(en lines)
|
99 |
+
# so devtest_pairs_normalized[en-as][tgt] will store tgt(as lines)
|
100 |
+
devtest_pairs_normalized = defaultdict(lambda: defaultdict(list))
|
101 |
+
SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)
|
102 |
+
benchmarks = os.listdir(devtest_dir)
|
103 |
+
for dataset in benchmarks:
|
104 |
+
for src_lang in SRC_LANGS:
|
105 |
+
for tgt_lang in TGT_LANGS:
|
106 |
+
if src_lang == tgt_lang:
|
107 |
+
continue
|
108 |
+
if dataset == "wat2021-devtest":
|
109 |
+
# wat2021 dev and test sets have differnet folder structure
|
110 |
+
src_dev = read_lines(f"{devtest_dir}/{dataset}/dev.{src_lang}")
|
111 |
+
tgt_dev = read_lines(f"{devtest_dir}/{dataset}/dev.{tgt_lang}")
|
112 |
+
src_test = read_lines(f"{devtest_dir}/{dataset}/test.{src_lang}")
|
113 |
+
tgt_test = read_lines(f"{devtest_dir}/{dataset}/test.{tgt_lang}")
|
114 |
+
else:
|
115 |
+
src_dev = read_lines(
|
116 |
+
f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/dev.{src_lang}"
|
117 |
+
)
|
118 |
+
tgt_dev = read_lines(
|
119 |
+
f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/dev.{tgt_lang}"
|
120 |
+
)
|
121 |
+
src_test = read_lines(
|
122 |
+
f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/test.{src_lang}"
|
123 |
+
)
|
124 |
+
tgt_test = read_lines(
|
125 |
+
f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/test.{tgt_lang}"
|
126 |
+
)
|
127 |
+
|
128 |
+
# if the tgt_pair data doesnt exist for a particular test set,
|
129 |
+
# it will be an empty list
|
130 |
+
if tgt_test == [] or tgt_dev == []:
|
131 |
+
# print(f'{dataset} does not have {src_lang}-{tgt_lang} data')
|
132 |
+
continue
|
133 |
+
|
134 |
+
# combine both dev and test sets into one
|
135 |
+
src_devtest = src_dev + src_test
|
136 |
+
tgt_devtest = tgt_dev + tgt_test
|
137 |
+
|
138 |
+
src_devtest = [strip_and_normalize(line) for line in src_devtest]
|
139 |
+
tgt_devtest = [strip_and_normalize(line) for line in tgt_devtest]
|
140 |
+
|
141 |
+
devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["src"].extend(
|
142 |
+
src_devtest
|
143 |
+
)
|
144 |
+
devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["tgt"].extend(
|
145 |
+
tgt_devtest
|
146 |
+
)
|
147 |
+
|
148 |
+
# dedup merged benchmark datasets
|
149 |
+
for src_lang in SRC_LANGS:
|
150 |
+
for tgt_lang in TGT_LANGS:
|
151 |
+
if src_lang == tgt_lang:
|
152 |
+
continue
|
153 |
+
src_devtest, tgt_devtest = (
|
154 |
+
devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["src"],
|
155 |
+
devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["tgt"],
|
156 |
+
)
|
157 |
+
# if the devtest data doesnt exist for the src-tgt pair then continue
|
158 |
+
if src_devtest == [] or tgt_devtest == []:
|
159 |
+
continue
|
160 |
+
src_devtest, tgt_devtest = pair_dedup_lists(src_devtest, tgt_devtest)
|
161 |
+
(
|
162 |
+
devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["src"],
|
163 |
+
devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["tgt"],
|
164 |
+
) = (
|
165 |
+
src_devtest,
|
166 |
+
tgt_devtest,
|
167 |
+
)
|
168 |
+
|
169 |
+
return devtest_pairs_normalized
|
170 |
+
|
171 |
+
|
172 |
+
def remove_train_devtest_overlaps(train_dir, devtest_dir, many2many=False):
|
173 |
+
|
174 |
+
devtest_pairs_normalized = normalize_and_gather_all_benchmarks(
|
175 |
+
devtest_dir, many2many
|
176 |
+
)
|
177 |
+
|
178 |
+
SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)
|
179 |
+
|
180 |
+
if not many2many:
|
181 |
+
all_src_sentences_normalized = []
|
182 |
+
for key in devtest_pairs_normalized:
|
183 |
+
all_src_sentences_normalized.extend(devtest_pairs_normalized[key]["src"])
|
184 |
+
# remove all duplicates. Now this contains all the normalized
|
185 |
+
# english sentences in all test benchmarks across all lang pair
|
186 |
+
all_src_sentences_normalized = list(set(all_src_sentences_normalized))
|
187 |
+
else:
|
188 |
+
all_src_sentences_normalized = None
|
189 |
+
|
190 |
+
src_overlaps = []
|
191 |
+
tgt_overlaps = []
|
192 |
+
for src_lang in SRC_LANGS:
|
193 |
+
for tgt_lang in TGT_LANGS:
|
194 |
+
if src_lang == tgt_lang:
|
195 |
+
continue
|
196 |
+
new_src_train = []
|
197 |
+
new_tgt_train = []
|
198 |
+
|
199 |
+
pair = f"{src_lang}-{tgt_lang}"
|
200 |
+
src_train = read_lines(f"{train_dir}/{pair}/train.{src_lang}")
|
201 |
+
tgt_train = read_lines(f"{train_dir}/{pair}/train.{tgt_lang}")
|
202 |
+
|
203 |
+
len_before = len(src_train)
|
204 |
+
if len_before == 0:
|
205 |
+
continue
|
206 |
+
|
207 |
+
src_train_normalized = [strip_and_normalize(line) for line in src_train]
|
208 |
+
tgt_train_normalized = [strip_and_normalize(line) for line in tgt_train]
|
209 |
+
|
210 |
+
if all_src_sentences_normalized:
|
211 |
+
src_devtest_normalized = all_src_sentences_normalized
|
212 |
+
else:
|
213 |
+
src_devtest_normalized = devtest_pairs_normalized[pair]["src"]
|
214 |
+
|
215 |
+
tgt_devtest_normalized = devtest_pairs_normalized[pair]["tgt"]
|
216 |
+
|
217 |
+
# compute all src and tgt super strict overlaps for a lang pair
|
218 |
+
overlaps = set(src_train_normalized) & set(src_devtest_normalized)
|
219 |
+
src_overlaps.extend(list(overlaps))
|
220 |
+
|
221 |
+
overlaps = set(tgt_train_normalized) & set(tgt_devtest_normalized)
|
222 |
+
tgt_overlaps.extend(list(overlaps))
|
223 |
+
# dictionaries offer o(1) lookup
|
224 |
+
src_overlaps_dict = {}
|
225 |
+
tgt_overlaps_dict = {}
|
226 |
+
for line in src_overlaps:
|
227 |
+
src_overlaps_dict[line] = 1
|
228 |
+
for line in tgt_overlaps:
|
229 |
+
tgt_overlaps_dict[line] = 1
|
230 |
+
|
231 |
+
# loop to remove the ovelapped data
|
232 |
+
idx = -1
|
233 |
+
for src_line_norm, tgt_line_norm in tqdm(
|
234 |
+
zip(src_train_normalized, tgt_train_normalized), total=len_before
|
235 |
+
):
|
236 |
+
idx += 1
|
237 |
+
if src_overlaps_dict.get(src_line_norm, None):
|
238 |
+
continue
|
239 |
+
if tgt_overlaps_dict.get(tgt_line_norm, None):
|
240 |
+
continue
|
241 |
+
new_src_train.append(src_train[idx])
|
242 |
+
new_tgt_train.append(tgt_train[idx])
|
243 |
+
|
244 |
+
len_after = len(new_src_train)
|
245 |
+
print(
|
246 |
+
f"Detected overlaps between train and devetest for {pair} is {len_before - len_after}"
|
247 |
+
)
|
248 |
+
print(f"saving new files at {train_dir}/{pair}/")
|
249 |
+
create_txt(f"{train_dir}/{pair}/train.{src_lang}", new_src_train)
|
250 |
+
create_txt(f"{train_dir}/{pair}/train.{tgt_lang}", new_tgt_train)
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
train_data_dir = sys.argv[1]
|
255 |
+
# benchmarks directory should contains all the test sets
|
256 |
+
devtest_data_dir = sys.argv[2]
|
257 |
+
if len(sys.argv) == 3:
|
258 |
+
many2many = False
|
259 |
+
elif len(sys.argv) == 4:
|
260 |
+
many2many = sys.argv[4]
|
261 |
+
if many2many.lower() == "true":
|
262 |
+
many2many = True
|
263 |
+
else:
|
264 |
+
many2many = False
|
265 |
+
remove_train_devtest_overlaps(train_data_dir, devtest_data_dir, many2many)
|