"""
BackTranslation class
-----------------------------------

"""


import random

from transformers import MarianMTModel, MarianTokenizer

from textattack.shared import AttackedText

from .sentence_transformation import SentenceTransformation


class BackTranslation(SentenceTransformation):
    """A type of sentence level transformation that takes in a text input,
    translates it into target language and translates it back to source
    language.

    letters_to_insert (string): letters allowed for insertion into words
    (used by some char-based transformations)

    src_lang (string): source language
    target_lang (string): target language, for the list of supported language check bottom of this page
    src_model: translation model from huggingface that translates from source language to target language
    target_model: translation model from huggingface that translates from target language to source language
    chained_back_translation: run back translation in a chain for more perturbation (for example, en-es-en-fr-en)

    Example::

        >>> from textattack.transformations.sentence_transformations import BackTranslation
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.augmentation import Augmenter

        >>> transformation = BackTranslation()
        >>> constraints = [RepeatModification(), StopwordModification()]
        >>> augmenter = Augmenter(transformation = transformation, constraints = constraints)
        >>> s = 'What on earth are you doing here.'

        >>> augmenter.augment(s)
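
        A sketch of chained back translation (the hop count of 2 is only an
        illustration; each hop samples a random supported target language, so
        the output varies between runs):

        >>> chained_transformation = BackTranslation(chained_back_translation=2)
        >>> chained_augmenter = Augmenter(transformation = chained_transformation, constraints = constraints)
        >>> chained_augmenter.augment(s)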
    """

    def __init__(
        self,
        src_lang="en",
        target_lang="es",
        src_model="Helsinki-NLP/opus-mt-ROMANCE-en",
        target_model="Helsinki-NLP/opus-mt-en-ROMANCE",
        chained_back_translation=0,
    ):
        self.src_lang = src_lang
        self.target_lang = target_lang
        self.target_model = MarianMTModel.from_pretrained(target_model)
        self.target_tokenizer = MarianTokenizer.from_pretrained(target_model)
        self.src_model = MarianMTModel.from_pretrained(src_model)
        self.src_tokenizer = MarianTokenizer.from_pretrained(src_model)
        self.chained_back_translation = chained_back_translation

    def translate(self, input, model, tokenizer, lang="es"):
        # convert the text to the model's input format
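        # multilingual Marian models such as opus-mt-en-ROMANCE expect a target
        # language token (e.g. ">>es<<") prepended to the sentence, while
        # to-English models take the raw sentence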
        src_texts = []
        if lang == "en":
            src_texts.append(input[0])
        else:
            if ">>" and "<<" not in lang:
                lang = ">>" + lang + "<< "
            src_texts.append(lang + input[0])

        # tokenize the input
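        # note: prepare_seq2seq_batch is deprecated in newer transformers
        # releases; tokenizer(src_texts, return_tensors="pt", padding=True)
        # builds an equivalent source-only batch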
        encoded_input = tokenizer.prepare_seq2seq_batch(src_texts, return_tensors="pt")

        # translate the input
        translated = model.generate(**encoded_input)
        translated_input = tokenizer.batch_decode(translated, skip_special_tokens=True)
        return translated_input

    def _get_transformations(self, current_text, indices_to_modify):
        transformed_texts = []
        current_text = current_text.text

        # to perform chained back translation, a random list of target languages is sampled from the languages supported by the target model
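        # the sampled entries are already Marian language tokens (e.g. ">>fr<<"),
        # so translate() prepends them to the text as-is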
        if self.chained_back_translation:
            list_of_target_lang = random.sample(
                self.target_tokenizer.supported_language_codes,
                self.chained_back_translation,
            )
            for target_lang in list_of_target_lang:
                target_language_text = self.translate(
                    [current_text],
                    self.target_model,
                    self.target_tokenizer,
                    target_lang,
                )
                src_language_text = self.translate(
                    target_language_text,
                    self.src_model,
                    self.src_tokenizer,
                    self.src_lang,
                )
                current_text = src_language_text[0]
            return [AttackedText(current_text)]

        # translate the source text to the target language and back to the source language (single back translation)
        target_language_text = self.translate(
            [current_text], self.target_model, self.target_tokenizer, self.target_lang
        )
        src_language_text = self.translate(
            target_language_text, self.src_model, self.src_tokenizer, self.src_lang
        )
        transformed_texts.append(AttackedText(src_language_text[0]))
        return transformed_texts


"""
List of supported target-language codes (for the default Helsinki-NLP/opus-mt-en-ROMANCE model)
['fr',
 'es',
 'it',
 'pt',
 'pt_br',
 'ro',
 'ca',
 'gl',
 'pt_BR',
 'la',
 'wa',
 'fur',
 'oc',
 'fr_CA',
 'sc',
 'es_ES',
 'es_MX',
 'es_AR',
 'es_PR',
 'es_UY',
 'es_CL',
 'es_CO',
 'es_CR',
 'es_GT',
 'es_HN',
 'es_NI',
 'es_PA',
 'es_PE',
 'es_VE',
 'es_DO',
 'es_EC',
 'es_SV',
 'an',
 'pt_PT',
 'frp',
 'lad',
 'vec',
 'fr_FR',
 'co',
 'it_IT',
 'lld',
 'lij',
 'lmo',
 'nap',
 'rm',
 'scn',
 'mwl']
"""