Spaces:
Runtime error
Runtime error
import itertools | |
from torch.utils.data import Dataset | |
import numpy as np | |
from joblib import Parallel, delayed | |
class OnePersonaDataset(Dataset): | |
def __init__(self, data, tokenizer, transforms=None, positive_candidates=True, n_jobs=8): | |
super().__init__() | |
self.data = data | |
if len(data) == 0: | |
self.input_ids = [] | |
self.history = [] | |
self.labels = [] | |
return | |
if positive_candidates: | |
self.history = [row['history'] + [row['candidates'][-1], ] for row in data] | |
self.labels = np.ones(len(self.history), dtype=int) | |
else: | |
self.history = [row['history'] + [candidate, ] for row in data | |
for candidate in row['candidates']] | |
self.labels = itertools.chain.from_iterable([0] * (len(row['candidates']) - 1) + [1] | |
for row in data) | |
self.labels = np.array(self.labels, dtype=int) | |
if transforms is None: | |
self.history = ["\n".join(item) for item in self.history] | |
else: | |
self.history = Parallel(n_jobs=n_jobs)(delayed(transforms)(item) for item in self.history) | |
self.input_ids = tokenizer(self.history, padding='max_length', truncation=True)["input_ids"] | |
def __getitem__(self, idx): | |
return {'input_ids': self.input_ids[idx], | |
'labels': self.input_ids[idx], | |
'example': self.history[idx], | |
'class': self.labels[idx]} | |
def __len__(self): | |
return len(self.data) | |
class PersonaChatDataset(Dataset): | |
DEFAULT_DATASET_NAME = "bavard/personachat_truecased" | |
def __init__(self, clustering, dataset, tokenizer): | |
super().__init__() | |
self.dataset = dataset | |
self.clustering = clustering | |
all_personalities = list(set([sent for item in self.dataset | |
for sent in item['personality']])) | |
predicted_centers = self.clustering.predict(all_personalities) | |
self.all_personalities_to_id = {persona: center | |
for persona, center in zip(all_personalities, predicted_centers)} | |
self.personalities = self.clustering._cluster_centers | |
subdataset_data_by_personality = [[] for _ in range(len(self.personalities))] | |
for i in range(len(self.dataset)): | |
item = self.dataset[i] | |
cur_persona_ids = [self.all_personalities_to_id[persona] for persona in item['personality']] | |
for persona_id in cur_persona_ids: | |
subdataset_data_by_personality[persona_id].append(item) | |
self.subdatasets = [OnePersonaDataset(cur_data, tokenizer) for cur_data in subdataset_data_by_personality] | |
def __getitem__(self, persona_id): | |
return self.subdatasets[persona_id] | |
def __len__(self, ): | |
return len(self.datasets) |