File size: 1,527 Bytes
63775f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""

Ted Multi TranslationDataset Class
------------------------------------
"""


import collections

import datasets
import numpy as np

from textattack.datasets import HuggingFaceDataset


class TedMultiTranslationDataset(HuggingFaceDataset):
    """Loads examples from the Ted Talk translation dataset using the
    `datasets` package.

    dataset source: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/

    Args:
        source_lang: Language code of the text used as model input.
        target_lang: Language code of the text used as the reference output.
        split: Name of the split of the ``ted_multi`` dataset to load.

    Raises:
        ValueError: If ``source_lang`` or ``target_lang`` is not among the
            languages listed by the first example of the split.
    """

    def __init__(self, source_lang="en", target_lang="de", split="test"):
        self._dataset = datasets.load_dataset("ted_multi")[split]
        self.examples = self._dataset["translations"]
        # The valid choices are taken from the first example only.
        # NOTE(review): other examples may list a different set of languages;
        # `_format_raw_example` reports that case per example.
        language_options = set(self.examples[0]["language"])
        for role, lang in (("Source", source_lang), ("Target", target_lang)):
            if lang not in language_options:
                raise ValueError(
                    f"{role} language {lang} invalid. Choices: {sorted(language_options)}"
                )
        self.source_lang = source_lang
        self.target_lang = target_lang

    def _format_raw_example(self, raw_example):
        """Convert one raw record into an ``(input, output)`` pair.

        Args:
            raw_example: Mapping with parallel ``"language"`` and
                ``"translation"`` sequences.

        Returns:
            A tuple ``(source_dict, target)`` where ``source_dict`` is an
            ``OrderedDict`` with the single key ``"Source"`` holding the text
            in ``self.source_lang`` and ``target`` is the text in
            ``self.target_lang``.

        Raises:
            IndexError: If the example does not list the source or target
                language.
        """
        languages = list(raw_example["language"])
        translations = raw_example["translation"]
        source = self._pick_translation(languages, translations, self.source_lang)
        target = self._pick_translation(languages, translations, self.target_lang)
        source_dict = collections.OrderedDict([("Source", source)])
        return (source_dict, target)

    @staticmethod
    def _pick_translation(languages, translations, lang):
        """Return the first entry of ``translations`` whose language is ``lang``."""
        try:
            position = languages.index(lang)
        except ValueError:
            # Keep IndexError (what the masked-array lookup used to raise) so
            # existing handlers still work, but say which language is missing.
            raise IndexError(
                f"Example has no translation for language {lang!r}. "
                f"Available languages: {languages}"
            ) from None
        return translations[position]