PFEemp2024's picture
add necessary file
63775f2
raw
history blame contribute delete
No virus
1.53 kB
"""
Ted Multi TranslationDataset Class
------------------------------------
"""
import collections
import datasets
import numpy as np
from textattack.datasets import HuggingFaceDataset
class TedMultiTranslationDataset(HuggingFaceDataset):
    """Loads examples from the Ted Talk translation dataset using the
    `datasets` package.

    dataset source: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/

    Each formatted example pairs the text in ``source_lang`` (the model
    input) with the text in ``target_lang`` (the expected output).
    """

    def __init__(self, source_lang="en", target_lang="de", split="test"):
        """Load one split of ``ted_multi`` and validate the language pair.

        Args:
            source_lang: Language code of the input text.
            target_lang: Language code of the reference translation.
            split: Name of the dataset split to load.

        Raises:
            ValueError: If ``source_lang`` or ``target_lang`` is not among
                the languages listed by the first example of the split.
        """
        self._dataset = datasets.load_dataset("ted_multi")[split]
        self.examples = self._dataset["translations"]

        # The set of accepted languages is taken from the first example only;
        # a language that this particular example lacks is rejected here.
        language_options = set(self.examples[0]["language"])

        # Source is checked before target so the first reported error is the
        # same one a caller would have seen with two separate checks.
        for role, lang in (("Source", source_lang), ("Target", target_lang)):
            if lang not in language_options:
                raise ValueError(
                    f"{role} language {lang} invalid. Choices: {sorted(language_options)}"
                )

        self.source_lang = source_lang
        self.target_lang = target_lang

    @staticmethod
    def _select_translation(translations, languages, lang):
        """Return the first entry of ``translations`` whose language is ``lang``.

        Args:
            translations: Array of translated texts.
            languages: Array of language codes, parallel to ``translations``.
            lang: Language code to look up.

        Raises:
            IndexError: If no entry of ``languages`` equals ``lang``.
        """
        matches = translations[languages == lang]
        if matches.size == 0:
            # Same exception type the bare ``[0]`` lookup would raise, but
            # with a message that says which language was missing.
            raise IndexError(
                f"Example has no translation for language {lang!r}. "
                f"Available: {sorted(str(code) for code in languages)}"
            )
        return matches[0]

    def _format_raw_example(self, raw_example):
        """Convert one raw record into a ``(source_dict, target)`` pair.

        Args:
            raw_example: Mapping with parallel ``"translation"`` and
                ``"language"`` sequences.

        Returns:
            A tuple of an ``OrderedDict`` holding the source text under the
            key ``"Source"`` and the target-language text.

        Raises:
            IndexError: If the record has no entry for ``self.source_lang``
                or ``self.target_lang``.
        """
        translations = np.array(raw_example["translation"])
        languages = np.array(raw_example["language"])

        source = self._select_translation(translations, languages, self.source_lang)
        target = self._select_translation(translations, languages, self.target_lang)

        source_dict = collections.OrderedDict([("Source", source)])
        return (source_dict, target)