j.gilyazev
add personalized-chat-bot
c1c5bd9
import itertools
from torch.utils.data import Dataset
import numpy as np
from joblib import Parallel, delayed
class OnePersonaDataset(Dataset):
def __init__(self, data, tokenizer, transforms=None, positive_candidates=True, n_jobs=8):
super().__init__()
self.data = data
if len(data) == 0:
self.input_ids = []
self.history = []
self.labels = []
return
if positive_candidates:
self.history = [row['history'] + [row['candidates'][-1], ] for row in data]
self.labels = np.ones(len(self.history), dtype=int)
else:
self.history = [row['history'] + [candidate, ] for row in data
for candidate in row['candidates']]
self.labels = itertools.chain.from_iterable([0] * (len(row['candidates']) - 1) + [1]
for row in data)
self.labels = np.array(self.labels, dtype=int)
if transforms is None:
self.history = ["\n".join(item) for item in self.history]
else:
self.history = Parallel(n_jobs=n_jobs)(delayed(transforms)(item) for item in self.history)
self.input_ids = tokenizer(self.history, padding='max_length', truncation=True)["input_ids"]
def __getitem__(self, idx):
return {'input_ids': self.input_ids[idx],
'labels': self.input_ids[idx],
'example': self.history[idx],
'class': self.labels[idx]}
def __len__(self):
return len(self.data)
class PersonaChatDataset(Dataset):
DEFAULT_DATASET_NAME = "bavard/personachat_truecased"
def __init__(self, clustering, dataset, tokenizer):
super().__init__()
self.dataset = dataset
self.clustering = clustering
all_personalities = list(set([sent for item in self.dataset
for sent in item['personality']]))
predicted_centers = self.clustering.predict(all_personalities)
self.all_personalities_to_id = {persona: center
for persona, center in zip(all_personalities, predicted_centers)}
self.personalities = self.clustering._cluster_centers
subdataset_data_by_personality = [[] for _ in range(len(self.personalities))]
for i in range(len(self.dataset)):
item = self.dataset[i]
cur_persona_ids = [self.all_personalities_to_id[persona] for persona in item['personality']]
for persona_id in cur_persona_ids:
subdataset_data_by_personality[persona_id].append(item)
self.subdatasets = [OnePersonaDataset(cur_data, tokenizer) for cur_data in subdataset_data_by_personality]
def __getitem__(self, persona_id):
return self.subdatasets[persona_id]
def __len__(self, ):
return len(self.datasets)