Iker committed on
Commit
5401d1a
1 Parent(s): 6e4adc1

Support for NLLB and Sampling

Files changed (3)
  1. README.md +81 -9
  2. supported_languages.md +214 -2
  3. translate.py +54 -7
README.md CHANGED
@@ -13,11 +13,7 @@
13
  <br>
14
  </p>
15
 
16
- Easy-Translate is a script for translating large text files in your machine using the [M2M100 models](https://arxiv.org/pdf/2010.11125.pdf) from Facebook/Meta AI. We also privide a [script](#evaluate-translations) for Easy-Evaluation of your translations 🥳
17
-
18
- **M2M100** is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.
19
-
20
- >M2M100 can directly translate between 9,900 directions of 100 languages.
21
 
22
  Easy-Translate is built on top of 🤗HuggingFace's [Transformers](https://huggingface.co/docs/transformers/index) and [Accelerate](https://huggingface.co/docs/accelerate/index) libraries.
23
 
@@ -27,26 +23,43 @@ We currently support:
27
  - BF16 / FP16 / FP32 precision.
28
  - Automatic batch size finder: Forget CUDA OOM errors. Set an initial batch size; if it doesn't fit, we will automatically adjust it.
29
  - Sharded Data Parallel to load huge models sharded on multiple GPUs (See: <https://huggingface.co/docs/accelerate/fsdp>).
 
30
 
31
  >Test the 🔌 Online Demo here: <https://huggingface.co/spaces/Iker/Translate-100-languages>
32
 
33
 
 
34
  ## Supported languages
35
 
36
  See the [Supported languages table](supported_languages.md) for the supported languages and their ids.
37
 
38
- **List of supported languages:**
39
- Afrikaans, Amharic, Arabic, Asturian, Azerbaijani, Bashkir, Belarusian, Bulgarian, Bengali, Breton, Bosnian, Catalan, Cebuano, Czech, Welsh, Danish, German, Greeek, English, Spanish, Estonian, Persian, Fulah, Finnish, French, WesternFrisian, Irish, Gaelic, Galician, Gujarati, Hausa, Hebrew, Hindi, Croatian, Haitian, Hungarian, Armenian, Indonesian, Igbo, Iloko, Icelandic, Italian, Japanese, Javanese, Georgian, Kazakh, CentralKhmer, Kannada, Korean, Luxembourgish, Ganda, Lingala, Lao, Lithuanian, Latvian, Malagasy, Macedonian, Malayalam, Mongolian, Marathi, Malay, Burmese, Nepali, Dutch, Norwegian, NorthernSotho, Occitan, Oriya, Panjabi, Polish, Pushto, Portuguese, Romanian, Russian, Sindhi, Sinhala, Slovak, Slovenian, Somali, Albanian, Serbian, Swati, Sundanese, Swedish, Swahili, Tamil, Thai, Tagalog, Tswana, Turkish, Ukrainian, Urdu, Uzbek, Vietnamese, Wolof, Xhosa, Yiddish, Yoruba, Chinese, Zulu
40
-
41
  ## Supported Models
42
 
43
  - **Facebook/m2m100_418M**: <https://huggingface.co/facebook/m2m100_418M>
44
 
45
  - **Facebook/m2m100_1.2B**: <https://huggingface.co/facebook/m2m100_1.2B>
46
 
47
  - **Facebook/m2m100_12B**: <https://huggingface.co/facebook/m2m100-12B-avg-5-ckpt>
48
 
49
- - Any other m2m100 model from HuggingFace's Hub: <https://huggingface.co/models?search=m2m100>
50
 
51
  ## Requirements
52
 
@@ -59,6 +72,9 @@ pip install --upgrade accelerate
59
 
60
  HuggingFace Transformers
61
  pip install --upgrade transformers
62
  ```
63
 
64
  ## Translate a file
@@ -109,6 +125,62 @@ accelerate launch translate.py \
109
  --precision fp16
110
  ```
111
 
112
  ## Evaluate translations
113
 
114
  To run the evaluation script you need to install [bert_score](https://github.com/Tiiiger/bert_score): `pip install bert_score` and 🤗HuggingFace's [Datasets](https://huggingface.co/docs/datasets/index) library: `pip install datasets`.
 
13
  <br>
14
  </p>
15
 
16
+ Easy-Translate is a script for translating large text files on your machine using the [M2M100 models](https://arxiv.org/pdf/2010.11125.pdf) and [NLLB200 models](https://research.facebook.com/publications/no-language-left-behind/) from Facebook/Meta AI. We also provide a [script](#evaluate-translations) for Easy-Evaluation of your translations 🥳
17
 
18
  Easy-Translate is built on top of 🤗HuggingFace's [Transformers](https://huggingface.co/docs/transformers/index) and [Accelerate](https://huggingface.co/docs/accelerate/index) libraries.
19
 
 
23
  - BF16 / FP16 / FP32 precision.
24
  - Automatic batch size finder: Forget CUDA OOM errors. Set an initial batch size; if it doesn't fit, we will automatically adjust it (see the sketch below this list).
25
  - Sharded Data Parallel to load huge models sharded on multiple GPUs (See: <https://huggingface.co/docs/accelerate/fsdp>).
26
+ - Greedy decoding / Beam Search decoding / Multinomial Sampling / Beam-Search Multinomial Sampling
27
 
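The automatic batch size finder relies on 🤗Accelerate's `find_executable_batch_size` decorator, which is how `translate.py` wraps its inference loop. A minimal sketch of the mechanism (the starting value of 128 and the function body are illustrative):

```python
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=128)
def inference(batch_size):
    # Accelerate calls this with batch_size=128 first; if a CUDA OOM error
    # is raised, it halves the batch size and retries until the call succeeds.
    print(f"Running inference with batch size {batch_size}")
    # ... build the DataLoader and run generation with this batch size ...

inference()  # called with no arguments; batch_size is injected by the decorator
```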
28
  >Test the 🔌 Online Demo here: <https://huggingface.co/spaces/Iker/Translate-100-languages>
29
 
30
 
31
+
32
  ## Supported languages
33
 
34
  See the [Supported languages table](supported_languages.md) for the supported languages and their ids.
35
 
36
  ## Supported Models
37
 
38
+ ### M2M100
39
+ **M2M100** is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.
40
+ >M2M100 can directly translate across 9,900 directions of 100 languages.
41
+
42
  - **Facebook/m2m100_418M**: <https://huggingface.co/facebook/m2m100_418M>
43
 
44
  - **Facebook/m2m100_1.2B**: <https://huggingface.co/facebook/m2m100_1.2B>
45
 
46
  - **Facebook/m2m100_12B**: <https://huggingface.co/facebook/m2m100-12B-avg-5-ckpt>
47
 
48
+ ### NLLB200
49
+
50
+ **No Language Left Behind (NLLB)** open-sources models capable of delivering high-quality translations directly between any pair of 200+ languages — including low-resource languages like Asturian, Luganda, Urdu and more. It aims to help people communicate with anyone, anywhere, regardless of their language preferences. It was introduced in this [paper](https://research.facebook.com/publications/no-language-left-behind/) and first released in [this](https://github.com/facebookresearch/fairseq/tree/nllb) repository.
51
+ >NLLB can directly translate across 40,000+ directions of 200+ languages.
52
+
53
+ - **facebook/nllb-200-3.3B**: <https://huggingface.co/facebook/nllb-200-3.3B>
54
+
55
+ - **facebook/nllb-200-1.3B**: <https://huggingface.co/facebook/nllb-200-1.3B>
56
+
57
+ - **facebook/nllb-200-distilled-1.3B**: <https://huggingface.co/facebook/nllb-200-distilled-1.3B>
58
+
59
+ - **facebook/nllb-200-distilled-600M**: <https://huggingface.co/facebook/nllb-200-distilled-600M>
60
+
61
+
62
+ Any other ModelForSeq2SeqLM from HuggingFace's Hub should work with this library: <https://huggingface.co/models?pipeline_tag=text2text-generation>
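As a rough sketch of what `translate.py` does under the hood, any of these checkpoints can be loaded through the Auto classes and used to translate a single sentence. The checkpoint name, example sentence, and `en`/`es` ids below are illustrative; NLLB200 checkpoints work the same way with FLORES-200 codes such as `eng_Latn` and `spa_Latn`:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/m2m100_418M"  # any seq2seq checkpoint from the Hub
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

tokenizer.src_lang = "en"  # source language id (see supported_languages.md)
inputs = tokenizer("Easy-Translate translates large text files.", return_tensors="pt")

generated_tokens = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.lang_code_to_id["es"],  # force the target language
    num_beams=5,
    max_length=128,
)
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
```

`translate.py` does roughly this at scale: it batches the input file, wraps generation in the automatic batch size finder, and writes the decoded translations to `--output_path`.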
63
 
64
  ## Requirements
65
 
 
72
 
73
  HuggingFace Transformers
74
  pip install --upgrade transformers
75
+
76
+ If you find errors using NLLB200, try installing transformers from source:
77
+ pip install git+https://github.com/huggingface/transformers.git
78
  ```
79
 
80
  ## Translate a file
 
125
  --precision fp16
126
  ```
127
 
128
+ ### Decoding/Sampling strategies
129
+
130
+ You can choose the decoding/sampling strategy and the number of candidate translations to output for each input sentence. By default, we use beam search with `num_beams` set to 5 and output the most likely candidate translation. You can change this behavior as follows:
131
+ ##### Greedy decoding
132
+ ```bash
133
+ accelerate launch translate.py \
134
+ --sentences_path sample_text/en.txt \
135
+ --output_path sample_text/en2es.translation.m2m100_1.2B.txt \
136
+ --source_lang en \
137
+ --target_lang es \
138
+ --model_name facebook/m2m100_1.2B \
139
+ --num_beams 1
140
+ ```
141
+
142
+ ##### Multinomial Sampling
143
+ ```bash
144
+ accelerate launch translate.py \
145
+ --sentences_path sample_text/en.txt \
146
+ --output_path sample_text/en2es.translation.m2m100_1.2B.txt \
147
+ --source_lang en \
148
+ --target_lang es \
149
+ --model_name facebook/m2m100_1.2B \
150
+ --num_beams 1 \
151
+ --do_sample \
152
+ --temperature 0.5 \
153
+ --top_k 100 \
154
+ --top_p 0.8 \
155
+ --num_return_sequences 1
156
+ ```
157
+ ##### Beam-Search decoding **(DEFAULT)**
158
+ ```bash
159
+ accelerate launch translate.py \
160
+ --sentences_path sample_text/en.txt \
161
+ --output_path sample_text/en2es.translation.m2m100_1.2B.txt \
162
+ --source_lang en \
163
+ --target_lang es \
164
+ --model_name facebook/m2m100_1.2B \
165
+ --num_beams 5 \
166
+ --num_return_sequences 1
167
+ ```
168
+ ##### Beam-Search Multinomial Sampling
169
+ ```bash
170
+ accelerate launch translate.py \
171
+ --sentences_path sample_text/en.txt \
172
+ --output_path sample_text/en2es.translation.m2m100_1.2B.txt \
173
+ --source_lang en \
174
+ --target_lang es \
175
+ --model_name facebook/m2m100_1.2B \
176
+ --num_beams 5 \
177
+ --num_return_sequences 1 \
178
+ --do_sample \
179
+ --temperature 0.5 \
180
+ --top_k 100 \
181
+ --top_p 0.8
182
+ ```
183
+
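These flags are forwarded to 🤗Transformers' `generate()` as the `gen_kwargs` built in `translate.py`. A minimal self-contained sketch of the equivalent call for the Beam-Search Multinomial Sampling example above (the checkpoint, sentence, and language ids are illustrative):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_1.2B")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/m2m100_1.2B")

tokenizer.src_lang = "en"
inputs = tokenizer("Sampling adds diversity to the translations.", return_tensors="pt")

generated_tokens = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.lang_code_to_id["es"],
    # Same knobs as the CLI flags above.
    max_length=128,
    num_beams=5,
    num_return_sequences=1,
    do_sample=True,
    temperature=0.5,
    top_k=100,
    top_p=0.8,
)
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
```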
184
  ## Evaluate translations
185
 
186
  To run the evaluation script you need to install [bert_score](https://github.com/Tiiiger/bert_score): `pip install bert_score` and 🤗HuggingFace's [Datasets](https://huggingface.co/docs/datasets/index) library: `pip install datasets`.
supported_languages.md CHANGED
@@ -1,4 +1,10 @@
1
- ## Supported languages
2
 
3
  | Language | Id |
4
  |---|---|
@@ -101,4 +107,210 @@
101
  | Yiddish | yi |
102
  | Yoruba | yo |
103
  | Chinese | zh |
104
- | Zulu | zu |
1
+ # List of supported languages
2
+
3
+ ## Index
4
+ * [M2M100 supported languages](#supported-languages-m2m100)
5
+ * [NLLB200 supported languages](#supported-languages-nllb200)
6
+
7
+ ## Supported languages M2M100
8
 
9
  | Language | Id |
10
  |---|---|
 
107
  | Yiddish | yi |
108
  | Yoruba | yo |
109
  | Chinese | zh |
110
+ | Zulu | zu |
111
+
112
+ ## Supported languages NLLB200
113
+ | Language id |
114
+ |-------------|
115
+ | ace_Arab |
116
+ | ace_Latn |
117
+ | acm_Arab |
118
+ | acq_Arab |
119
+ | aeb_Arab |
120
+ | afr_Latn |
121
+ | ajp_Arab |
122
+ | aka_Latn |
123
+ | amh_Ethi |
124
+ | apc_Arab |
125
+ | arb_Arab |
126
+ | ars_Arab |
127
+ | ary_Arab |
128
+ | arz_Arab |
129
+ | asm_Beng |
130
+ | ast_Latn |
131
+ | awa_Deva |
132
+ | ayr_Latn |
133
+ | azb_Arab |
134
+ | azj_Latn |
135
+ | bak_Cyrl |
136
+ | bam_Latn |
137
+ | ban_Latn |
138
+ | bel_Cyrl |
139
+ | bem_Latn |
140
+ | ben_Beng |
141
+ | bho_Deva |
142
+ | bjn_Arab |
143
+ | bjn_Latn |
144
+ | bod_Tibt |
145
+ | bos_Latn |
146
+ | bug_Latn |
147
+ | bul_Cyrl |
148
+ | cat_Latn |
149
+ | ceb_Latn |
150
+ | ces_Latn |
151
+ | cjk_Latn |
152
+ | ckb_Arab |
153
+ | crh_Latn |
154
+ | cym_Latn |
155
+ | dan_Latn |
156
+ | deu_Latn |
157
+ | dik_Latn |
158
+ | dyu_Latn |
159
+ | dzo_Tibt |
160
+ | ell_Grek |
161
+ | eng_Latn |
162
+ | epo_Latn |
163
+ | est_Latn |
164
+ | eus_Latn |
165
+ | ewe_Latn |
166
+ | fao_Latn |
167
+ | pes_Arab |
168
+ | fij_Latn |
169
+ | fin_Latn |
170
+ | fon_Latn |
171
+ | fra_Latn |
172
+ | fur_Latn |
173
+ | fuv_Latn |
174
+ | gla_Latn |
175
+ | gle_Latn |
176
+ | glg_Latn |
177
+ | grn_Latn |
178
+ | guj_Gujr |
179
+ | hat_Latn |
180
+ | hau_Latn |
181
+ | heb_Hebr |
182
+ | hin_Deva |
183
+ | hne_Deva |
184
+ | hrv_Latn |
185
+ | hun_Latn |
186
+ | hye_Armn |
187
+ | ibo_Latn |
188
+ | ilo_Latn |
189
+ | ind_Latn |
190
+ | isl_Latn |
191
+ | ita_Latn |
192
+ | jav_Latn |
193
+ | jpn_Jpan |
194
+ | kab_Latn |
195
+ | kac_Latn |
196
+ | kam_Latn |
197
+ | kan_Knda |
198
+ | kas_Arab |
199
+ | kas_Deva |
200
+ | kat_Geor |
201
+ | knc_Arab |
202
+ | knc_Latn |
203
+ | kaz_Cyrl |
204
+ | kbp_Latn |
205
+ | kea_Latn |
206
+ | khm_Khmr |
207
+ | kik_Latn |
208
+ | kin_Latn |
209
+ | kir_Cyrl |
210
+ | kmb_Latn |
211
+ | kon_Latn |
212
+ | kor_Hang |
213
+ | kmr_Latn |
214
+ | lao_Laoo |
215
+ | lvs_Latn |
216
+ | lij_Latn |
217
+ | lim_Latn |
218
+ | lin_Latn |
219
+ | lit_Latn |
220
+ | lmo_Latn |
221
+ | ltg_Latn |
222
+ | ltz_Latn |
223
+ | lua_Latn |
224
+ | lug_Latn |
225
+ | luo_Latn |
226
+ | lus_Latn |
227
+ | mag_Deva |
228
+ | mai_Deva |
229
+ | mal_Mlym |
230
+ | mar_Deva |
231
+ | min_Latn |
232
+ | mkd_Cyrl |
233
+ | plt_Latn |
234
+ | mlt_Latn |
235
+ | mni_Beng |
236
+ | khk_Cyrl |
237
+ | mos_Latn |
238
+ | mri_Latn |
239
+ | zsm_Latn |
240
+ | mya_Mymr |
241
+ | nld_Latn |
242
+ | nno_Latn |
243
+ | nob_Latn |
244
+ | npi_Deva |
245
+ | nso_Latn |
246
+ | nus_Latn |
247
+ | nya_Latn |
248
+ | oci_Latn |
249
+ | gaz_Latn |
250
+ | ory_Orya |
251
+ | pag_Latn |
252
+ | pan_Guru |
253
+ | pap_Latn |
254
+ | pol_Latn |
255
+ | por_Latn |
256
+ | prs_Arab |
257
+ | pbt_Arab |
258
+ | quy_Latn |
259
+ | ron_Latn |
260
+ | run_Latn |
261
+ | rus_Cyrl |
262
+ | sag_Latn |
263
+ | san_Deva |
264
+ | sat_Beng |
265
+ | scn_Latn |
266
+ | shn_Mymr |
267
+ | sin_Sinh |
268
+ | slk_Latn |
269
+ | slv_Latn |
270
+ | smo_Latn |
271
+ | sna_Latn |
272
+ | snd_Arab |
273
+ | som_Latn |
274
+ | sot_Latn |
275
+ | spa_Latn |
276
+ | als_Latn |
277
+ | srd_Latn |
278
+ | srp_Cyrl |
279
+ | ssw_Latn |
280
+ | sun_Latn |
281
+ | swe_Latn |
282
+ | swh_Latn |
283
+ | szl_Latn |
284
+ | tam_Taml |
285
+ | tat_Cyrl |
286
+ | tel_Telu |
287
+ | tgk_Cyrl |
288
+ | tgl_Latn |
289
+ | tha_Thai |
290
+ | tir_Ethi |
291
+ | taq_Latn |
292
+ | taq_Tfng |
293
+ | tpi_Latn |
294
+ | tsn_Latn |
295
+ | tso_Latn |
296
+ | tuk_Latn |
297
+ | tum_Latn |
298
+ | tur_Latn |
299
+ | twi_Latn |
300
+ | tzm_Tfng |
301
+ | uig_Arab |
302
+ | ukr_Cyrl |
303
+ | umb_Latn |
304
+ | urd_Arab |
305
+ | uzn_Latn |
306
+ | vec_Latn |
307
+ | vie_Latn |
308
+ | war_Latn |
309
+ | wol_Latn |
310
+ | xho_Latn |
311
+ | ydd_Hebr |
312
+ | yor_Latn |
313
+ | yue_Hant |
314
+ | zho_Hans |
315
+ | zho_Hant |
316
+ | zul_Latn |
translate.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import (
2
- M2M100ForConditionalGeneration,
3
- M2M100Tokenizer,
4
  PreTrainedTokenizerBase,
5
  DataCollatorForSeq2Seq,
6
  )
@@ -60,6 +60,10 @@ def main(
60
  max_length: int = 128,
61
  num_beams: int = 4,
62
  num_return_sequences: int = 1,
63
  ):
64
 
65
  if not os.path.exists(os.path.abspath(os.path.dirname(output_path))):
@@ -70,11 +74,11 @@ def main(
70
  )
71
 
72
  print(f"Loading tokenizer {model_name}...")
73
- tokenizer = M2M100Tokenizer.from_pretrained(
74
  pretrained_model_name_or_path=model_name, cache_dir=cache_dir
75
  )
76
  print(f"Loading model {model_name}...")
77
- model = M2M100ForConditionalGeneration.from_pretrained(
78
  pretrained_model_name_or_path=model_name, cache_dir=cache_dir
79
  )
80
 
@@ -92,12 +96,21 @@ def main(
92
  raise ValueError("Precision not supported. Supported values: 32, fp16, bf16")
93
 
94
  tokenizer.src_lang = source_lang
95
- lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]
96
 
97
  gen_kwargs = {
98
  "max_length": max_length,
99
  "num_beams": num_beams,
100
  "num_return_sequences": num_return_sequences,
101
  }
102
 
103
  # total_lines: int = count_lines(sentences_path)
@@ -114,10 +127,12 @@ def main(
114
  f"Num. Devices: {accelerator.num_processes}\n"
115
  f"Distributed_type: {accelerator.distributed_type}\n"
116
  f"Max length: {max_length}\n"
117
- f"Num beams: {num_beams}\n"
118
  f"Precision: {model.dtype}\n"
119
  f"Model: {model_name}\n"
120
  )
121
 
122
  @find_executable_batch_size(starting_batch_size=starting_batch_size)
123
  def inference(batch_size):
@@ -167,7 +182,8 @@ def main(
167
  if accelerator.is_main_process:
168
  if step == len(data_loader) - 1:
169
  tgt_text = tgt_text[
170
- : len(data_loader.dataset) - samples_seen
 
171
  ]
172
  else:
173
  samples_seen += len(tgt_text)
@@ -262,6 +278,33 @@ if __name__ == "__main__":
262
  help="Precision of the model. bf16, fp16 or 32.",
263
  )
264
 
265
  args = parser.parse_args()
266
 
267
  main(
@@ -276,4 +319,8 @@ if __name__ == "__main__":
276
  num_beams=args.num_beams,
277
  num_return_sequences=args.num_return_sequences,
278
  precision=args.precision,
279
  )
 
1
  from transformers import (
2
+ AutoModelForSeq2SeqLM,
3
+ AutoTokenizer,
4
  PreTrainedTokenizerBase,
5
  DataCollatorForSeq2Seq,
6
  )
 
60
  max_length: int = 128,
61
  num_beams: int = 4,
62
  num_return_sequences: int = 1,
63
+ do_sample: bool = False,
64
+ temperature: float = 1.0,
65
+ top_k: int = 50,
66
+ top_p: float = 1.0,
67
  ):
68
 
69
  if not os.path.exists(os.path.abspath(os.path.dirname(output_path))):
 
74
  )
75
 
76
  print(f"Loading tokenizer {model_name}...")
77
+ tokenizer = AutoTokenizer.from_pretrained(
78
  pretrained_model_name_or_path=model_name, cache_dir=cache_dir
79
  )
80
  print(f"Loading model {model_name}...")
81
+ model = AutoModelForSeq2SeqLM.from_pretrained(
82
  pretrained_model_name_or_path=model_name, cache_dir=cache_dir
83
  )
84
 
 
96
  raise ValueError("Precision not supported. Supported values: 32, fp16, bf16")
97
 
98
  tokenizer.src_lang = source_lang
99
+ try:
100
+ lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]
101
+ except KeyError:
102
+ raise KeyError(
103
+ f"Language {target_lang} not found in tokenizer. Available languages: {tokenizer.lang_code_to_id.keys()}"
104
+ )
105
 
106
  gen_kwargs = {
107
  "max_length": max_length,
108
  "num_beams": num_beams,
109
  "num_return_sequences": num_return_sequences,
110
+ "do_sample": do_sample,
111
+ "temperature": temperature,
112
+ "top_k": top_k,
113
+ "top_p": top_p,
114
  }
115
 
116
  # total_lines: int = count_lines(sentences_path)
 
127
  f"Num. Devices: {accelerator.num_processes}\n"
128
  f"Distributed_type: {accelerator.distributed_type}\n"
129
  f"Max length: {max_length}\n"
 
130
  f"Precision: {model.dtype}\n"
131
  f"Model: {model_name}\n"
132
  )
133
+ print("** Generation parameters **")
134
+ print("\n".join(f"{k}: {v}" for k, v in gen_kwargs.items()))
135
+ print("\n")
136
 
137
  @find_executable_batch_size(starting_batch_size=starting_batch_size)
138
  def inference(batch_size):
 
182
  if accelerator.is_main_process:
183
  if step == len(data_loader) - 1:
184
  tgt_text = tgt_text[
185
+ : len(data_loader.dataset) * num_return_sequences
186
+ - samples_seen
187
  ]
188
  else:
189
  samples_seen += len(tgt_text)
 
278
  help="Precision of the model. bf16, fp16 or 32.",
279
  )
280
 
281
+ parser.add_argument(
282
+ "--do_sample",
283
+ action="store_true",
284
+ help="Use sampling instead of beam search.",
285
+ )
286
+
287
+ parser.add_argument(
288
+ "--temperature",
289
+ type=float,
290
+ default=1.0,
291
+ help="Temperature for sampling, value used only if do_sample is True.",
292
+ )
293
+
294
+ parser.add_argument(
295
+ "--top_k",
296
+ type=int,
297
+ default=50,
298
+ help="If do_sample is True, will sample from the top k most likely tokens.",
299
+ )
300
+
301
+ parser.add_argument(
302
+ "--top_p",
303
+ type=float,
304
+ default=1.0,
305
+ help="If do_sample is True, will sample from the top k most likely tokens.",
306
+ )
307
+
308
  args = parser.parse_args()
309
 
310
  main(
 
319
  num_beams=args.num_beams,
320
  num_return_sequences=args.num_return_sequences,
321
  precision=args.precision,
322
+ do_sample=args.do_sample,
323
+ temperature=args.temperature,
324
+ top_k=args.top_k,
325
+ top_p=args.top_p,
326
  )