Iker committed on
Commit 9dcafee
1 Parent(s): f88323f

Implement SeamlessM4T

README.md CHANGED
@@ -1,4 +1,3 @@
-
 <p align="center">
 <br>
 <img src="images/title.png" width="900"/>
@@ -29,23 +28,19 @@ We currently support:
 - BF16 / FP16 / FP32 / 8 Bits / 4 Bits precision.
 - Automatic batch size finder: Forget CUDA OOM errors. Set an initial batch size, if it doesn't fit, we will automatically adjust it.
 - Multiple decoding strategies: Greedy Search, Beam Search, Top-K Sampling, Top-p (nucleus) sampling, etc. See [Decoding Strategies](#decodingsampling-strategies) for more information.
-- :new: Load huge models in a single GPU with 8-bits / 4-bits quantization and support for splitting the model between GPU and CPU. See [Loading Huge Models](#loading-huge-models) for more information.
-- :new: LoRA models support
-- :new: Support for any Seq2SeqLM or CausalLM model from HuggingFace's Hub.
-- :new: Prompt support! See [Prompting](#prompting) for more information.
+- Load huge models in a single GPU with 8-bits / 4-bits quantization and support for splitting the model between GPU and CPU. See [Loading Huge Models](#loading-huge-models) for more information.
+- LoRA models support
+- Support for any Seq2SeqLM or CausalLM model from HuggingFace's Hub.
+- Prompt support! See [Prompting](#prompting) for more information.
+- :new: Support for [SeamlessM4T](https://huggingface.co/docs/transformers/main/en/model_doc/seamless_m4t)!
 
 >Test the 🔌 Online Demo here: <https://huggingface.co/spaces/Iker/Translate-100-languages>
 
 
-
-## Supported languages
-
-See the [Supported languages table](supported_languages.md) for a table of the supported languages and their ids.
-
 ## Supported Models
 
 💥 EasyTranslate now supports any Seq2SeqLM (m2m100, nllb200, small100, mbart, MarianMT, T5, FlanT5, etc.) and any CausalLM (GPT2, LLaMA, Vicuna, Falcon) model from 🤗 Hugging Face's Hub!!
-We still recommend you to use M2M100 or NLLB200 for the best results, but you can experiment with any other MT model, as well as prompting LLMs to generate translations (See [Prompting Section](#prompting) for more details).
+We still recommend using M2M100, NLLB200 or SeamlessM4T for the best results, but you can experiment with any other MT model, as well as prompting LLMs to generate translations (see the [Prompting Section](#prompting) for more details).
 You can also see [the examples folder](examples) for examples of how to use EasyTranslate with different models.
 
 ### M2M100
@@ -73,13 +68,23 @@ You can also see [the examples folder](examples) for examples of how to use Easy
 
 - **facebook/nllb-200-distilled-600M**: <https://huggingface.co/facebook/nllb-200-distilled-600M>
 
+### SeamlessM4T
+
+**SeamlessM4T** is a collection of models designed to provide high-quality translation, allowing people from different linguistic communities to communicate effortlessly through speech and text. It was introduced in this [paper](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) and first released in [this](https://github.com/facebookresearch/seamless_communication) repository.
+>SeamlessM4T can directly translate between 196 languages for text input/output.
+
+- **facebook/hf-seamless-m4t-medium**: <https://huggingface.co/facebook/hf-seamless-m4t-medium> (Requires transformers 4.35.0)
+
+- **facebook/hf-seamless-m4t-large**: <https://huggingface.co/facebook/hf-seamless-m4t-large> (Requires transformers 4.35.0)
+
+
 ### Other MT Models supported
 We support every MT model in the 🤗 Hugging Face's Hub. If you find a model that doesn't work, please open an issue for us to fix it or a PR with the fix. This includes, among many others:
 - **Small100**: <https://huggingface.co/alirezamsh/small100>
 - **Mbart many-to-many / many-to-one**: <https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt>
 - **Opus MT**: <https://huggingface.co/Helsinki-NLP/opus-mt-es-en>
 
-
+See the [Supported languages table](supported_languages.md) for a table of the supported languages and their ids.
 
 ## Citation
 If you use this software please cite
@@ -110,6 +115,7 @@ pip install accelerate
 
 HuggingFace Transformers
 If you plan to use NLLB200, please use >= 4.28.0, as an important bug was fixed in this version.
+If you plan to use SeamlessM4T, please use >= 4.35.0.
 pip install --upgrade transformers
 
 BitsAndBytes (Optional, required for 8-bits / 4-bits quantization)
@@ -135,6 +141,20 @@ python3 translate.py \
 --model_name facebook/m2m100_1.2B
 ```
 
+If you want to translate all the files in a directory, use the `--sentences_dir` flag instead of `--sentences_path`.
+```bash
+# We use --files_extension txt to translate only files with this extension.
+# Use an empty string to translate all files in the directory.
+
+python3 translate.py \
+--sentences_dir sample_text/ \
+--output_path sample_text/translations \
+--files_extension txt \
+--source_lang en \
+--target_lang es \
+--model_name facebook/m2m100_1.2B
+```
+
 #### Multi-GPU
 
 See Accelerate documentation for more information (multi-node, TPU, Sharded model...): <https://huggingface.co/docs/accelerate/index>
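For reference, the SeamlessM4T support added here builds on the text-to-text classes introduced in transformers 4.35. A minimal sketch of translating a sentence directly with that API (the model name is one of the checkpoints listed above; the call itself is our illustration, not part of this commit):

```python
# Minimal sketch using the transformers 4.35 SeamlessM4T text-to-text API.
from transformers import AutoProcessor, SeamlessM4TForTextToText

processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium")

# SeamlessM4T uses three-letter language codes, e.g. "eng" and "spa".
inputs = processor(
    text="Hello, world, my name is Iker!", src_lang="eng", return_tensors="pt"
)
output_tokens = model.generate(**inputs, tgt_lang="spa")
print(processor.decode(output_tokens[0].tolist(), skip_special_tokens=True))
```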
model.py CHANGED
@@ -14,8 +14,6 @@ from transformers.models.auto.modeling_auto import (
 
 from typing import Optional, Tuple
 
-import os
-
 import torch
 
 import json
@@ -27,6 +25,7 @@ def load_model_for_inference(
     lora_weights_name_or_path: Optional[str] = None,
     torch_dtype: Optional[str] = None,
     force_auto_device_map: bool = False,
+    trust_remote_code: bool = False,
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
     """
     Load any Decoder model for inference.
@@ -50,6 +49,8 @@
             Whether to force the use of the auto device map. If set to True, the model will be split across
             GPUs and CPU to fit the model in memory. If set to False, a full copy of the model will be loaded
             into each GPU. Defaults to False.
+        trust_remote_code (`bool`, optional):
+            Trust the remote code from HuggingFace model hub. Defaults to False.
 
     Returns:
         `Tuple[PreTrainedModel, PreTrainedTokenizerBase]`:
@@ -64,19 +65,8 @@
 
     print(f"Loading model from {weights_path}")
 
-    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
-        {
-            "mpt": "MPTForCausalLM",
-            "RefinedWebModel": "RWForCausalLM",
-            "RefinedWeb": "RWForCausalLM",
-        }
-    )  # MPT and Falcon are not in transformers yet
-
     config = AutoConfig.from_pretrained(
-        weights_path,
-        trust_remote_code=True
-        if ("mpt" in weights_path or "falcon" in weights_path)
-        else False,
+        weights_path, trust_remote_code=trust_remote_code
     )
 
     torch_dtype = (
@@ -84,20 +74,40 @@
     )
 
     if "small100" in weights_path:
+        import transformers
+
+        if transformers.__version__ > "4.34.0":
+            raise ValueError(
+                "Small100 tokenizer is not supported in transformers > 4.34.0. Please "
+                "use transformers <= 4.34.0 if you want to use small100"
+            )
+
         print(f"Loading custom small100 tokenizer for utils.tokenization_small100")
         from utils.tokenization_small100 import SMALL100Tokenizer as AutoTokenizer
     else:
         from transformers import AutoTokenizer
 
     tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
-        weights_path,
-        add_eos_token=True,
-        trust_remote_code=True
-        if ("mpt" in weights_path or "falcon" in weights_path)
-        else False,
+        weights_path, add_eos_token=True, trust_remote_code=trust_remote_code
     )
 
+    if tokenizer.pad_token_id is None:
+        if "<|padding|>" in tokenizer.get_vocab():
+            # StableLM specific fix
+            tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
+        elif tokenizer.unk_token is not None:
+            print(
+                "Tokenizer does not have a pad token, we will use the unk token as pad token."
+            )
+            tokenizer.pad_token_id = tokenizer.unk_token_id
+        else:
+            print(
+                "Tokenizer does not have a pad token. We will use the eos token as pad token."
+            )
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+
     quant_args = {}
+
     if quantization is not None:
         quant_args = (
             {"load_in_4bit": True} if quantization == 4 else {"load_in_8bit": True}
@@ -107,16 +117,17 @@
                 load_in_4bit=True,
                 bnb_4bit_use_double_quant=True,
                 bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch.bfloat16,
+                bnb_4bit_compute_dtype=torch.bfloat16
+                if torch_dtype in ["auto", None]
+                else torch_dtype,
             )
-            torch_dtype = torch.bfloat16
 
         else:
             bnb_config = BitsAndBytesConfig(
                 load_in_8bit=True,
            )
         print(
-            f"Bits and Bytes config: {json.dumps(bnb_config.to_dict(),indent=4,ensure_ascii=False)}"
+            f"Bits and Bytes config: {json.dumps(bnb_config.to_dict(), indent=4, ensure_ascii=False)}"
         )
     else:
         print(f"Loading model with dtype: {torch_dtype}")
@@ -131,6 +142,7 @@
             device_map="auto" if force_auto_device_map else None,
             torch_dtype=torch_dtype,
             quantization_config=bnb_config,
+            trust_remote_code=trust_remote_code,
             **quant_args,
         )
 
@@ -142,9 +154,7 @@
             pretrained_model_name_or_path=weights_path,
             device_map="auto" if force_auto_device_map else None,
             torch_dtype=torch_dtype,
-            trust_remote_code=True
-            if ("mpt" in weights_path or "falcon" in weights_path)
-            else False,
+            trust_remote_code=trust_remote_code,
             quantization_config=bnb_config,
             **quant_args,
         )
@@ -159,21 +169,6 @@
         f"CausalLM: {MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}\n"
     )
 
-    if tokenizer.pad_token_id is None:
-        if "<|padding|>" in tokenizer.get_vocab():
-            # StableLM specific fix
-            tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
-        elif tokenizer.unk_token is not None:
-            print(
-                "Model does not have a pad token, we will use the unk token as pad token."
-            )
-            tokenizer.pad_token_id = tokenizer.unk_token_id
-        else:
-            print(
-                "Model does not have a pad token. We will use the eos token as pad token."
-            )
-            tokenizer.pad_token_id = tokenizer.eos_token_id
-
     if lora_weights_name_or_path:
         from peft import PeftModel
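A minimal sketch of how the updated loader is meant to be called after this change (keyword names are taken from the diff above; the values are illustrative, not part of this commit):

```python
from model import load_model_for_inference

# trust_remote_code now comes from the caller instead of being inferred
# from the model name, so custom-code models must opt in explicitly.
model, tokenizer = load_model_for_inference(
    weights_path="facebook/hf-seamless-m4t-medium",
    quantization=None,            # or 4 / 8 for BitsAndBytes quantization
    torch_dtype=None,             # default dtype handling happens inside the loader
    force_auto_device_map=False,
    trust_remote_code=False,      # set True for models that ship custom code (e.g. MPT, Falcon)
)
```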
sample_text/en2es.seamless-m4t-large.json ADDED
@@ -0,0 +1,60 @@
+{
+    "path": "sample_text/en2es.translation.seamless-m4t-large.txt",
+    "sacrebleu": {
+        "score": 36.315142112223896,
+        "counts": [
+            20334,
+            12742,
+            8758,
+            6156
+        ],
+        "totals": [
+            31021,
+            30021,
+            29021,
+            28021
+        ],
+        "precisions": [
+            65.54914412817124,
+            42.44362279737517,
+            30.178146859170944,
+            21.969237357696013
+        ],
+        "bp": 0.9854077938820913,
+        "sys_len": 31021,
+        "ref_len": 31477
+    },
+    "rouge": {
+        "rouge1": 0.6330701226501922,
+        "rouge2": 0.4284215608900075,
+        "rougeL": 0.5852948888167713,
+        "rougeLsum": 0.5852893813466102
+    },
+    "bleu": {
+        "bleu": 0.36315142112223897,
+        "precisions": [
+            0.6554914412817124,
+            0.4244362279737517,
+            0.30178146859170946,
+            0.21969237357696014
+        ],
+        "brevity_penalty": 0.9854077938820913,
+        "length_ratio": 0.9855132318835975,
+        "translation_length": 31021,
+        "reference_length": 31477
+    },
+    "meteor": {
+        "meteor": 0.5988659867679048
+    },
+    "ter": {
+        "score": 53.42233524051706,
+        "num_edits": 15126,
+        "ref_length": 28314.0
+    },
+    "bert_score": {
+        "precision": 0.8355873214006424,
+        "recall": 0.8343284497857094,
+        "f1": 0.8346186644434929,
+        "hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.35.2)_fast-tokenizer"
+    }
+}
sample_text/en2es.seamless-m4t-medium.json ADDED
@@ -0,0 +1,60 @@
+{
+    "path": "sample_text/en2es.translation.seamless-m4t-medium.txt",
+    "sacrebleu": {
+        "score": 32.86110838375764,
+        "counts": [
+            19564,
+            11721,
+            7752,
+            5264
+        ],
+        "totals": [
+            30811,
+            29811,
+            28811,
+            27812
+        ],
+        "precisions": [
+            63.49680308980559,
+            39.31770151957331,
+            26.90638992051647,
+            18.92708183517906
+        ],
+        "bp": 0.978616287348328,
+        "sys_len": 30811,
+        "ref_len": 31477
+    },
+    "rouge": {
+        "rouge1": 0.609193205717968,
+        "rouge2": 0.3944070815557623,
+        "rougeL": 0.558841464797821,
+        "rougeLsum": 0.5594046328281417
+    },
+    "bleu": {
+        "bleu": 0.3286110838375765,
+        "precisions": [
+            0.6349680308980559,
+            0.3931770151957331,
+            0.2690638992051647,
+            0.1892708183517906
+        ],
+        "brevity_penalty": 0.978616287348328,
+        "length_ratio": 0.9788416939352543,
+        "translation_length": 30811,
+        "reference_length": 31477
+    },
+    "meteor": {
+        "meteor": 0.5707261528520716
+    },
+    "ter": {
+        "score": 55.88754679663771,
+        "num_edits": 15824,
+        "ref_length": 28314.0
+    },
+    "bert_score": {
+        "precision": 0.8278114783763886,
+        "recall": 0.824702616840601,
+        "f1": 0.8259151731133461,
+        "hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.35.2)_fast-tokenizer"
+    }
+}
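The two JSON files above follow the metric layout produced by the Hugging Face `evaluate` library. A sketch of how one such entry can be generated (our illustration; the repo's evaluation script is not part of this commit, and the file paths are hypothetical):

```python
# Sketch: computing a sacrebleu entry like the ones above with `evaluate`.
import evaluate

# Hypothetical hypothesis/reference files, one sentence per line.
predictions = open("sample_text/en2es.translation.seamless-m4t-medium.txt").read().splitlines()
references = [[ref] for ref in open("sample_text/es.txt").read().splitlines()]

sacrebleu = evaluate.load("sacrebleu")
print(sacrebleu.compute(predictions=predictions, references=references))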
sample_text/en2es.translation.seamless-m4t-large.txt ADDED
The diff for this file is too large to render. See raw diff
 
sample_text/en2es.translation.seamless-m4t-medium.txt ADDED
The diff for this file is too large to render. See raw diff
 
tests/__init__.py ADDED
File without changes
tests/test_translation.py ADDED
@@ -0,0 +1,548 @@
+# Run with 'python -m unittest tests.test_translation'
+
+import unittest
+import tempfile
+import os
+from translate import main
+import transformers
+
+
+class Inputs(unittest.TestCase):
+    def test_m2m100_inputs(self):
+        # Create a temporary directory
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # Create a temporary file
+
+            input_path = os.path.join(tmpdirname, "source.txt")
+            output_path = os.path.join(tmpdirname, "target.txt")
+
+            with open(
+                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
+            ) as f:
+                print("Hello, world, my name is Iker!", file=f)
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang="en",
+                target_lang="es",
+                starting_batch_size=32,
+                model_name="facebook/m2m100_418M",
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision=None,
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+            main(
+                sentences_path=None,
+                sentences_dir=tmpdirname,
+                files_extension="txt",
+                output_path=os.path.join(tmpdirname, "target"),
+                source_lang="en",
+                target_lang="es",
+                starting_batch_size=32,
+                model_name="facebook/m2m100_418M",
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision=None,
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+
+class Translations(unittest.TestCase):
+    def test_m2m100(self):
+        # Create a temporary directory
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # Create a temporary file
+
+            input_path = os.path.join(tmpdirname, "source.txt")
+            output_path = os.path.join(tmpdirname, "target.txt")
+
+            with open(
+                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
+            ) as f:
+                print("Hello, world, my name is Iker!", file=f)
+
+            model_name = "facebook/m2m100_418M"
+            src_lang = "en"
+            tgt_lang = "es"
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="bf16",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="4",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+    def test_nllb200(self):
+        # Create a temporary directory
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # Create a temporary file
+
+            input_path = os.path.join(tmpdirname, "source.txt")
+            output_path = os.path.join(tmpdirname, "target.txt")
+
+            with open(
+                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
+            ) as f:
+                print("Hello, world, my name is Iker!", file=f)
+
+            model_name = "facebook/nllb-200-distilled-600M"
+            src_lang = "eng_Latn"
+            tgt_lang = "spa_Latn"
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="bf16",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="4",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+    def test_mbart(self):
+        # Create a temporary directory
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # Create a temporary file
+
+            input_path = os.path.join(tmpdirname, "source.txt")
+            output_path = os.path.join(tmpdirname, "target.txt")
+
+            with open(
+                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
+            ) as f:
+                print("Hello, world, my name is Iker!", file=f)
+
+            model_name = "facebook/mbart-large-50"
+            src_lang = "en_XX"
+            tgt_lang = "es_XX"
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="bf16",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="4",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+    def test_opus(self):
+        # Create a temporary directory
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # Create a temporary file
+
+            input_path = os.path.join(tmpdirname, "source.txt")
+            output_path = os.path.join(tmpdirname, "target.txt")
+
+            with open(
+                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
+            ) as f:
+                print("Hello, world, my name is Iker!", file=f)
+
+            model_name = "Helsinki-NLP/opus-mt-en-es"
+            src_lang = None
+            tgt_lang = None
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=False,
+                precision="bf16",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=False,
+                precision="4",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+    @unittest.skipIf(
+        transformers.__version__ > "4.34.0",
+        "Small100 tokenizer is not supported in transformers > 4.34.0. Please use transformers <= 4.34.0 if you want to use small100",
+    )
+    def test_small100(self):
+        # Create a temporary directory
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # Create a temporary file
+
+            input_path = os.path.join(tmpdirname, "source.txt")
+            output_path = os.path.join(tmpdirname, "target.txt")
+
+            with open(
+                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
+            ) as f:
+                print("Hello, world, my name is Iker!", file=f)
+
+            model_name = "alirezamsh/small100"
+            src_lang = None
+            tgt_lang = "es"
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="bf16",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="4",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+    def test_seamless(self):
+        # Create a temporary directory
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # Create a temporary file
+
+            input_path = os.path.join(tmpdirname, "source.txt")
+            output_path = os.path.join(tmpdirname, "target.txt")
+
+            with open(
+                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
+            ) as f:
+                print("Hello, world, my name is Iker!", file=f)
+
+            model_name = "facebook/hf-seamless-m4t-medium"
+            src_lang = "eng"
+            tgt_lang = "spa"
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="bf16",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=src_lang,
+                target_lang=tgt_lang,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="4",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=False,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=None,
+            )
+
+
+class Prompting(unittest.TestCase):
+    def test_llama(self):
+        # Create a temporary directory
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # Create a temporary file
+
+            input_path = os.path.join(tmpdirname, "source.txt")
+            output_path = os.path.join(tmpdirname, "target.txt")
+
+            with open(
+                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
+            ) as f:
+                print("Hello, world, my name is Iker!", file=f)
+
+            model_name = "stas/tiny-random-llama-2"
+            prompt = "Translate English to Spanish: %%SENTENCE%%"
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=None,
+                target_lang=None,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="bf16",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=True,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=prompt,
+            )
+
+            main(
+                sentences_path=input_path,
+                sentences_dir=None,
+                files_extension="txt",
+                output_path=output_path,
+                source_lang=None,
+                target_lang=None,
+                starting_batch_size=32,
+                model_name=model_name,
+                lora_weights_name_or_path=None,
+                force_auto_device_map=True,
+                precision="4",
+                max_length=64,
+                num_beams=2,
+                num_return_sequences=1,
+                do_sample=True,
+                temperature=1.0,
+                top_k=50,
+                top_p=1.0,
+                keep_special_tokens=False,
+                keep_tokenization_spaces=False,
+                repetition_penalty=None,
+                prompt=prompt,
+            )
translate.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import math
 import argparse
+import glob
 
 import torch
 from torch.utils.data import DataLoader
@@ -18,6 +19,8 @@ from dataset import DatasetReader, count_lines
 
 from accelerate import Accelerator, DistributedType, find_executable_batch_size
 
+from typing import Optional
+
 
 def encode_string(text):
     return text.replace("\r", r"\r").replace("\n", r"\n").replace("\t", r"\t")
@@ -31,7 +34,12 @@ def get_dataloader(
     max_length: int,
     prompt: str,
 ) -> DataLoader:
-    dataset = DatasetReader(filename, tokenizer, max_length, prompt)
+    dataset = DatasetReader(
+        filename=filename,
+        tokenizer=tokenizer,
+        max_length=max_length,
+        prompt=prompt,
+    )
     if accelerator.distributed_type == DistributedType.TPU:
         data_collator = DataCollatorForSeq2Seq(
             tokenizer,
@@ -59,16 +67,18 @@
 
 
 def main(
-    sentences_path: str,
+    sentences_path: Optional[str],
+    sentences_dir: Optional[str],
+    files_extension: str,
     output_path: str,
-    source_lang: str,
-    target_lang: str,
+    source_lang: Optional[str],
+    target_lang: Optional[str],
     starting_batch_size: int,
     model_name: str = "facebook/m2m100_1.2B",
     lora_weights_name_or_path: str = None,
     force_auto_device_map: bool = False,
     precision: str = None,
-    max_length: int = 128,
+    max_length: int = 256,
     num_beams: int = 4,
     num_return_sequences: int = 1,
     do_sample: bool = False,
@@ -79,9 +89,8 @@ def main(
     keep_tokenization_spaces: bool = False,
    repetition_penalty: float = None,
     prompt: str = None,
+    trust_remote_code: bool = False,
 ):
-    os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
-
     accelerator = Accelerator()
 
     if force_auto_device_map and starting_batch_size >= 64:
@@ -92,6 +101,16 @@
             f"inference. You should consider using a smaller batch size, i.e '--starting_batch_size 8'"
         )
 
+    if sentences_path is None and sentences_dir is None:
+        raise ValueError(
+            "You must specify either --sentences_path or --sentences_dir. Use --help for more details."
+        )
+
+    if sentences_path is not None and sentences_dir is not None:
+        raise ValueError(
+            "You must specify either --sentences_path or --sentences_dir, not both. Use --help for more details."
+        )
+
     if precision is None:
         quantization = None
         dtype = None
@@ -118,11 +137,17 @@
         lora_weights_name_or_path=lora_weights_name_or_path,
         torch_dtype=dtype,
         force_auto_device_map=force_auto_device_map,
+        trust_remote_code=trust_remote_code,
     )
 
     is_translation_model = hasattr(tokenizer, "lang_code_to_id")
+    lang_code_to_idx = None
 
-    if is_translation_model and (source_lang is None or target_lang is None):
+    if (
+        is_translation_model
+        and (source_lang is None or target_lang is None)
+        and "small100" not in model_name
+    ):
         raise ValueError(
             f"The model you are using requires a source and target language. "
             f"Please specify them with --source-lang and --target-lang. "
@@ -169,8 +194,32 @@
         # We don't need to force the BOS token, so we set is_translation_model to False
         is_translation_model = False
 
+    if model.config.model_type == "seamless_m4t":
+        # Loading a seamless_m4t model, we need to set a few things to ensure compatibility
+
+        supported_langs = tokenizer.additional_special_tokens
+        supported_langs = [lang.replace("__", "") for lang in supported_langs]
+
+        if source_lang is None or target_lang is None:
+            raise ValueError(
+                f"The model you are using requires a source and target language. "
+                f"Please specify them with --source-lang and --target-lang. "
+                f"The supported languages are: {supported_langs}"
+            )
+
+        if source_lang not in supported_langs:
+            raise ValueError(
+                f"Language {source_lang} not found in tokenizer. Available languages: {supported_langs}"
+            )
+        if target_lang not in supported_langs:
+            raise ValueError(
+                f"Language {target_lang} not found in tokenizer. Available languages: {supported_langs}"
+            )
+
+        tokenizer.src_lang = source_lang
+
     gen_kwargs = {
-        "max_length": max_length,
+        "max_new_tokens": max_length,
         "num_beams": num_beams,
         "num_return_sequences": num_return_sequences,
         "do_sample": do_sample,
@@ -182,12 +231,17 @@
     if repetition_penalty is not None:
         gen_kwargs["repetition_penalty"] = repetition_penalty
 
-    total_lines: int = count_lines(sentences_path)
+    if is_translation_model:
+        gen_kwargs["forced_bos_token_id"] = lang_code_to_idx
+
+    if model.config.model_type == "seamless_m4t":
+        gen_kwargs["tgt_lang"] = target_lang
 
     if accelerator.is_main_process:
         print(
             f"** Translation **\n"
             f"Input file: {sentences_path}\n"
+            f"Sentences dir: {sentences_dir}\n"
             f"Output file: {output_path}\n"
             f"Source language: {source_lang}\n"
             f"Target language: {target_lang}\n"
@@ -211,10 +265,12 @@
     print("\n")
 
     @find_executable_batch_size(starting_batch_size=starting_batch_size)
-    def inference(batch_size):
-        nonlocal model, tokenizer, sentences_path, max_length, output_path, lang_code_to_idx, gen_kwargs, precision, prompt, is_translation_model
+    def inference(batch_size, sentences_path, output_path):
+        nonlocal model, tokenizer, max_length, gen_kwargs, precision, prompt, is_translation_model
+
+        print(f"Translating {sentences_path} with batch size {batch_size}")
 
-        print(f"Translating with batch size {batch_size}")
+        total_lines: int = count_lines(sentences_path)
 
         data_loader = get_dataloader(
             accelerator=accelerator,
@@ -243,9 +299,6 @@
 
             generated_tokens = accelerator.unwrap_model(model).generate(
                 **batch,
-                forced_bos_token_id=lang_code_to_idx
-                if is_translation_model
-                else None,
                 **gen_kwargs,
             )
 
@@ -286,24 +339,60 @@
 
                 pbar.update(len(tgt_text) // gen_kwargs["num_return_sequences"])
 
-    inference()
+        print(f"Translation done. Output written to {output_path}\n")
+
+    if sentences_path is not None:
+        os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
+        inference(sentences_path=sentences_path, output_path=output_path)
+
+    if sentences_dir is not None:
+        print(
+            f"Translating all files in {sentences_dir}, with extension {files_extension}"
+        )
+        os.makedirs(os.path.abspath(output_path), exist_ok=True)
+        for filename in glob.glob(
+            os.path.join(
+                sentences_dir, f"*.{files_extension}" if files_extension else "*"
+            )
+        ):
+            output_filename = os.path.join(output_path, os.path.basename(filename))
+            inference(sentences_path=filename, output_path=output_filename)
+
     print(f"Translation done.\n")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the translation experiments")
-    parser.add_argument(
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    input_group.add_argument(
         "--sentences_path",
+        default=None,
         type=str,
-        required=True,
         help="Path to a txt file containing the sentences to translate. One sentence per line.",
     )
 
+    input_group.add_argument(
+        "--sentences_dir",
+        type=str,
+        default=None,
+        help="Path to a directory containing the sentences to translate. "
+        "Sentences must be in .txt files containing one sentence per line.",
+    )
+
+    parser.add_argument(
+        "--files_extension",
+        type=str,
+        default="txt",
+        help="If sentences_dir is specified, extension of the files to translate. Defaults to txt. "
+        "If set to an empty string, we will translate all files in the directory.",
+    )
+
     parser.add_argument(
         "--output_path",
        type=str,
         required=True,
-        help="Path to a txt file where the translated sentences will be written.",
+        help="Path to a txt file where the translated sentences will be written. If the input is a directory, "
+        "the output will be a directory with the same structure.",
     )
 
     parser.add_argument(
@@ -355,7 +444,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--max_length",
         type=int,
-        default=128,
+        default=256,
        help="Maximum number of tokens in the source sentence and generated sentence. "
         "Increase this value to translate longer sentences, at the cost of increasing memory usage.",
     )
@@ -438,10 +527,18 @@
         "It must include the special token %%SENTENCE%% which will be replaced by the sentence to translate.",
     )
 
+    parser.add_argument(
+        "--trust_remote_code",
+        action="store_true",
+        help="If set we will trust remote code in HuggingFace models. This is required for some models.",
+    )
+
    args = parser.parse_args()
 
     main(
         sentences_path=args.sentences_path,
+        sentences_dir=args.sentences_dir,
+        files_extension=args.files_extension,
         output_path=args.output_path,
         source_lang=args.source_lang,
         target_lang=args.target_lang,
@@ -459,4 +556,5 @@ if __name__ == "__main__":
         keep_tokenization_spaces=args.keep_tokenization_spaces,
         repetition_penalty=args.repetition_penalty,
         prompt=args.prompt,
+        trust_remote_code=args.trust_remote_code,
     )
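Putting the pieces together, the new SeamlessM4T path through `translate.py` can be exercised exactly as the added `test_seamless` does. A minimal sketch mirroring that test (the input/output paths are illustrative):

```python
# Sketch: translating a file with SeamlessM4T through the updated main(),
# mirroring tests/test_translation.py::Translations.test_seamless.
from translate import main

main(
    sentences_path="sample_text/source.txt",  # hypothetical input file, one sentence per line
    sentences_dir=None,
    files_extension="txt",
    output_path="sample_text/target.txt",     # hypothetical output file
    source_lang="eng",                        # SeamlessM4T three-letter language codes
    target_lang="spa",
    starting_batch_size=32,
    model_name="facebook/hf-seamless-m4t-medium",
    lora_weights_name_or_path=None,
    force_auto_device_map=False,
    precision="bf16",
    max_length=256,
    num_beams=4,
    num_return_sequences=1,
    do_sample=False,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    keep_special_tokens=False,
    keep_tokenization_spaces=False,
    repetition_penalty=None,
    prompt=None,
    trust_remote_code=False,
)
```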