import itertools
from torch.utils.data import Dataset
import numpy as np
from joblib import Parallel, delayed


class OnePersonaDataset(Dataset):
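    """Dataset of dialogs assigned to a single persona cluster.

    Each example is a dialog history with one reply candidate appended; ``labels``
    marks whether that candidate is the ground-truth reply (1) or a distractor (0).
    """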
    def __init__(self, data, tokenizer, transforms=None, positive_candidates=True, n_jobs=8):
        super().__init__()

        self.data = data
        if len(data) == 0:
            self.input_ids = []
            self.history = []
            self.labels = []
            return

        if positive_candidates:
            # Keep only the ground-truth reply (the last candidate) of each dialog.
            self.history = [row['history'] + [row['candidates'][-1]] for row in data]
            self.labels = np.ones(len(self.history), dtype=int)
        else:
            # Keep every candidate: distractors are labeled 0, the true reply
            # (the last candidate of each dialog) is labeled 1.
            self.history = [row['history'] + [candidate] for row in data
                            for candidate in row['candidates']]
            labels = itertools.chain.from_iterable([0] * (len(row['candidates']) - 1) + [1]
                                                   for row in data)
            self.labels = np.fromiter(labels, dtype=int)

        # Flatten each dialog into a single string, either with the user-supplied
        # transform (parallelized over ``n_jobs`` workers) or by joining turns with newlines.
        if transforms is None:
            self.history = ["\n".join(item) for item in self.history]
        else:
            self.history = Parallel(n_jobs=n_jobs)(delayed(transforms)(item) for item in self.history)
        self.input_ids = tokenizer(self.history, padding='max_length', truncation=True)["input_ids"]

    def __getitem__(self, idx):
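        # ``labels`` mirrors ``input_ids`` (language-modeling targets); the binary
        # candidate label is exposed separately under ``class``.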
        return {'input_ids': self.input_ids[idx],
                'labels': self.input_ids[idx],
                'example': self.history[idx],
                'class': self.labels[idx]}

    def __len__(self):
        return len(self.data)


class PersonaChatDataset(Dataset):
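    """Splits a PersonaChat-style dataset into one ``OnePersonaDataset`` per persona cluster.

    ``clustering`` must expose ``predict`` over persona sentences and the fitted
    ``_cluster_centers``; indexing by ``persona_id`` returns the sub-dataset for that cluster.
    """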
    DEFAULT_DATASET_NAME = "bavard/personachat_truecased"

    def __init__(self, clustering, dataset, tokenizer):
        super().__init__()

        self.dataset = dataset
        self.clustering = clustering

        # Map every distinct persona sentence to the id of its predicted cluster.
        all_personalities = list({sent for item in self.dataset
                                  for sent in item['personality']})
        predicted_centers = self.clustering.predict(all_personalities)
        self.all_personalities_to_id = {persona: center
                                        for persona, center in zip(all_personalities, predicted_centers)}
        self.personalities = self.clustering._cluster_centers

        # Group dialogs by cluster: a dialog is assigned to every cluster that one
        # of its persona sentences belongs to.
        subdataset_data_by_personality = [[] for _ in range(len(self.personalities))]

        for i in range(len(self.dataset)):
            item = self.dataset[i]
            cur_persona_ids = [self.all_personalities_to_id[persona] for persona in item['personality']]
            for persona_id in cur_persona_ids:
                subdataset_data_by_personality[persona_id].append(item)

        self.subdatasets = [OnePersonaDataset(cur_data, tokenizer) for cur_data in subdataset_data_by_personality]

    def __getitem__(self, persona_id):
        return self.subdatasets[persona_id]

    def __len__(self):
        return len(self.subdatasets)
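

# Minimal usage sketch: build a OnePersonaDataset from a couple of hand-written
# dialogs with a Hugging Face tokenizer. The model name and the dummy rows below
# are illustrative assumptions only, not part of the training pipeline.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    dummy_rows = [
        {"history": ["hi , how are you ?", "great , i just walked my dog ."],
         "candidates": ["i hate dogs .", "what breed is it ?"]},
    ]
    dataset = OnePersonaDataset(dummy_rows, tokenizer, n_jobs=1)
    print(len(dataset), dataset[0]["class"])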