# Build error
# Build error
# File size: 18,066 Bytes
# f3772cc |
# (line-number gutter 1-349 from the source viewer)
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
import datasets
import pandas as pd
import pyarrow
import pytorch_lightning as pl
import torchmetrics
import torch.nn as nn
import torch
import types
import multiprocessing
from .text_cleaning import clean_text_funcs
# Dataset wrapper that turns regret/recommendation pairs into tokenizer-ready
# train/test splits, exposed as .train_dataset / .test_dataset.
# NOTE(review): indentation and several lines of this class were lost when the
# file was extracted; the code below is not valid Python as-is.
class RRUMDataset():
# Numeric (non-text) features copied alongside the tokenized inputs.
scalar_features = ['channel_sim']
_image_features = ['regret_thumbnail',
'recommendation_thumbnail'] # not used atm
def __init__(self, data, with_transcript, cross_encoder_model_name_or_path, label_col="label", label_map=None, balance_label_counts=False, max_length=128, do_train_test_split=False, test_size=0.25, seed=42, keep_video_ids_for_predictions=False, encode_on_the_fly=False, clean_text=False, processing_batch_size=1000, processing_num_proc=1):
# data: a pandas DataFrame, a pyarrow.Table, or a generator of
# pyarrow.RecordBatch (the latter becomes a streaming IterableDataset).
# label_col=None means unlabeled data (only .test_dataset is filled).
# processing_num_proc falls back to multiprocessing.cpu_count() when falsy.
self._with_transcript = with_transcript
self.tokenizer = AutoTokenizer.from_pretrained(
# NOTE(review): the argument line of this call is missing from the extracted
# source (presumably cross_encoder_model_name_or_path) — confirm.
self.label_col = label_col
self.label_map = label_map
self.balance_label_counts = balance_label_counts
self.max_length = max_length
self.seed = seed
self.keep_video_ids_for_predictions = keep_video_ids_for_predictions
self.clean_text = clean_text
self.processing_batch_size = processing_batch_size
self.processing_num_proc = multiprocessing.cpu_count(
) if not processing_num_proc else processing_num_proc
# Text columns come in regret_/recommendation_ pairs; transcripts are optional.
self.text_types = ['title', 'description'] + \
(['transcript'] if self._with_transcript else [])
self._text_features = [
'regret_title', 'recommendation_title', 'regret_description',
'recommendation_description'] + (['regret_transcript', 'recommendation_transcript'] if self._with_transcript else [])
self.streaming_dataset = False
if isinstance(data, pd.DataFrame):
self.dataset = datasets.Dataset.from_pandas(data)
elif isinstance(data, types.GeneratorType):
# Wrap the generator as a streaming IterableDataset and peek at its first
# example to learn the column names (IterableDataset has no column_names).
# NOTE(review): `data` is a one-shot generator and each iteration of the
# dataset calls _streaming_generate_examples on that same object, so this
# peek appears to consume the first record batch; later passes would start
# from the second batch — confirm.
examples_iterable = datasets.iterable_dataset.ExamplesIterable(
self._streaming_generate_examples, {"iterable": data})
self.dataset = datasets.IterableDataset(examples_iterable)
self._stream_dataset_example = next(iter(self.dataset))
self._stream_dataset_column_names = list(
# NOTE(review): the argument line of list(...) is missing from the extracted
# source (presumably the keys of self._stream_dataset_example) — confirm.
self.streaming_dataset = True
elif isinstance(data, pyarrow.Table):
self.dataset = datasets.Dataset(data)
# NOTE(review): an `else:` line appears to be missing here; without it this
# raise would run for every input type.
raise ValueError(
f'Type of data is {type(data)} when pd.DataFrame, pyarrow.Table, or generator of pyarrow.RecordBatch is allowed')
self.train_dataset = None
self.test_dataset = None
# NOTE(review): self._preprocess() is not called anywhere in the visible code;
# confirm whether a call was lost around here.
if self.streaming_dataset:
# IterableDataset doesn't have train_test_split method
if self.label_col:
self.train_dataset = self._encode_streaming(self.dataset)
print('Streaming dataset available in .train_dataset')
# NOTE(review): an `else:` line (before the next statement) and a `print(`
# line (before the string literal) appear to be missing.
self.test_dataset = self._encode_streaming(self.dataset)
'Streaming dataset available in .test_dataset because label_col=None')
# NOTE(review): the `else:` opening the non-streaming branch and the first
# part of the comment below appear to be missing.
# dataset into train_dataset and/or test_dataset
if do_train_test_split:
# stratify_by_column requires the label column to be a datasets.ClassLabel
# feature (see the datasets train_test_split documentation).
ds = self.dataset.train_test_split(
test_size=test_size, shuffle=True, seed=self.seed, stratify_by_column=self.label_col)
self.train_dataset = ds['train']
self.test_dataset = ds['test']
f'Dataset was splitted into train and test with test_size={test_size}')
# NOTE(review): the `print(` opening the line above and an `else:` line appear
# to be missing. The assignments below presumably mirror the streaming branch:
# whole dataset -> train when labelled, otherwise -> test — confirm.
if self.label_col:
self.train_dataset = self.dataset
self.test_dataset = self.dataset
# encode_on_the_fly presumably attaches _encode_on_the_fly as a transform
# instead of pre-tokenizing; otherwise each split is tokenized up front.
if encode_on_the_fly:
if self.train_dataset:
# NOTE(review): the statement that attaches the on-the-fly encoding appears
# to be missing here and in the test branch below.
print('On-the-fly encoded dataset available in .train_dataset')
if self.test_dataset:
print('On-the-fly encoded dataset available in .test_dataset')
# NOTE(review): an `else:` line appears to be missing here.
if self.train_dataset:
self.train_dataset = self._encode(self.train_dataset)
print('Pre-encoded dataset available in .train_dataset')
if self.test_dataset:
self.test_dataset = self._encode(self.test_dataset)
print('Pre-encoded dataset available in .test_dataset')
def __len__(self):
if self.streaming_dataset:
raise ValueError(
f'Streaming dataset does not support len() method')
return len(self.dataset)
def __getitem__(self, index):
if self.streaming_dataset:
return next(iter(self.dataset))
return self.dataset[index]
def _streaming_generate_examples(self, iterable):
id_ = 0
# TODO: make sure GeneratorType is pyarrow.RecordBatch
if isinstance(iterable, types.GeneratorType):
for examples in iterable:
for ex in examples.to_pylist():
yield id_, ex
id_ += 1
def _preprocess(self):
# Clean up self.dataset in place. The steps are: filter rows on transcript
# availability, encode string labels, drop rows with missing or empty
# fields, optionally balance label counts, and optionally clean the text.
# NOTE(review): indentation and parts of several lines were lost in extraction.
if self._with_transcript:
# Keep only pairs where both videos have a transcript.
self.dataset = self.dataset.filter(
lambda example: example['regret_transcript'] is not None and example['recommendation_transcript'] is not None)
# NOTE(review): an `else:` line appears to be missing before this filter; as
# written it would remove every row the filter above kept. Presumably the
# no-transcript case keeps only pairs lacking a transcript — confirm.
self.dataset = self.dataset.filter(
lambda example: example['regret_transcript'] is None or example['recommendation_transcript'] is None)
if self.label_col:
if self.streaming_dataset:
# Streaming: string labels can only be encoded with a caller-supplied label_map.
if self.label_col in self._stream_dataset_column_names and isinstance(self._stream_dataset_example[self.label_col], str):
if not self.label_map:
raise ValueError(
f'"label_map" dict was not provided and is needed to encode string labels for streaming datasets')
# cast_column method had issues with streaming dataset
# NOTE(review): the right-hand side of this assignment was lost in extraction
# (presumably a map over self._streaming_rename_labels) — confirm.
self.dataset =
# NOTE(review): presumably the non-streaming branch (an `else:` line seems to
# be missing). String labels are encoded as a datasets.ClassLabel below.
if self.dataset.features[self.label_col].dtype == 'string':
if not self.label_map:
# Map each label value to its enumeration index; the argument of
# enumerate(...) was lost in extraction — confirm.
self.label_map = {k: v for v, k in enumerate(
# Drop rows whose label is not a key of label_map.
self.dataset = self.dataset.filter(
lambda example: example[self.label_col] in self.label_map.keys())
# ClassLabel assigns indices by position in `names`, so this assumes each
# label_map value equals the position of its key — TODO confirm for
# caller-supplied maps.
self.dataset = self.dataset.cast_column(self.label_col, datasets.ClassLabel(
num_classes=len(self.label_map), names=list(self.label_map.keys())))
self.dataset = self.dataset.filter(lambda example: not any(x in [None, ""] for x in [
example[key] for key in self._text_features + self.scalar_features + ([self.label_col] if self.label_col else [])])) # dropna
# Downsample every label to the row count of the rarest label (non-streaming only).
if self.balance_label_counts and self.label_col and not self.streaming_dataset:
# NOTE(review): self.label_map stays None unless it was passed in or derived
# above from string labels; in that case .values() raises AttributeError.
label_datasets = {}
for label in list(self.label_map.values()):
label_dataset = self.dataset.filter(
lambda example: example[self.label_col] == label)
# NOTE(review): keyed by row count, so two labels with the same count
# overwrite each other and one label disappears from the balanced dataset.
# Keying by `label` would avoid that.
label_datasets[len(label_dataset)] = label_dataset
min_label_count = min(label_datasets)
sampled_datasets = [dataset.train_test_split(train_size=min_label_count, shuffle=True, seed=self.seed)[
'train'] if len(dataset) != min_label_count else dataset for dataset in label_datasets.values()]
self.dataset = datasets.concatenate_datasets(sampled_datasets)
if self.clean_text:
# NOTE(review): the start of the two map calls below was lost in extraction,
# along with their remaining arguments. Presumably they are
# self.dataset.map(self._clean_text, ...) and a second map (possibly
# self._truncate_and_strip_text, not necessarily under this `if`) — confirm.
self.dataset =, batched=not self.streaming_dataset,
self.dataset =, batched=not self.streaming_dataset,
def _streaming_rename_labels(self, example):
# rename labels according to label_map if not already correct labels
# Labels that are not keys of label_map become None (label_map.get(..., None)).
if isinstance(example[self.label_col], list):
# NOTE(review): the `if ex not in self.label_map.values()` clause drops
# labels that are already mapped instead of keeping them. That shortens the
# label list relative to the rest of the batch. Mapping unmapped labels and
# keeping mapped ones looks like the intent — confirm.
example[self.label_col] = [self.label_map.get(
ex, None) for ex in example[self.label_col] if ex not in self.label_map.values()]
elif isinstance(example[self.label_col], str) and example[self.label_col] not in self.label_map.values():
example[self.label_col] = self.label_map.get(
example[self.label_col], None)
# NOTE(review): an `else:` line appears to be missing before this raise. Even
# with it, a string label that is already a label_map value would reach this
# raise.
raise ValueError(
f'Type of example label is {type(example[self.label_col])} when list or string is allowed')
return example
def _clean_text(self, example):
for feat in self._text_features:
example[feat] = clean_text_funcs(example[feat])[0] if isinstance(
example[feat], str) else clean_text_funcs(example[feat])
return example
def _truncate_and_strip_text(self, example):
# tokenizer will truncate to max_length tokens anyway so to save RAM let's truncate to max_length words already beforehand
# one word is usually one or more tokens so should be safe to truncate this way without losing information
# Works on a single example (str fields) or a batch (list fields).
for feat in self._text_features:
if isinstance(example[feat], list):
# NOTE(review): `if text` drops empty/None entries, so this column can end up
# shorter than the other columns of the same batch.
example[feat] = [
' '.join(text.split()[:self.max_length]).strip() for text in example[feat] if text]
# NOTE(review): the end of the slice on the str branch below (presumably
# :self.max_length]).strip()) was lost in extraction — confirm.
elif isinstance(example[feat], str):
example[feat] = ' '.join(example[feat].split()[
elif example[feat] is None:
return None
# NOTE(review): an `else:` line appears to be missing before this raise.
# A None field (branch above) makes the whole call return None instead of
# the example.
raise ValueError(
f'Type of example is {type(example[feat])} when list or string is allowed')
return example
def _encode(self, dataset):
# Pre-tokenize a (non-streaming) Dataset in these steps:
# - tokenize each text type as a regret/recommendation sentence pair into
#   its own columns, prefixed with the text type;
# - join the per-type results column-wise;
# - copy over the scalar features, the label and (optionally) the video ids;
# - apply the type='torch' formatting fragment at the end.
encoded_dataset = None
for text_type in self.text_types:
# NOTE(review): the text between `=` and `regret, recommendation:`
# (presumably dataset.map(lambda ) was lost in extraction — confirm.
encoded_text_type = regret, recommendation: self.tokenizer(regret, recommendation, padding="max_length", truncation=True, max_length=self.max_length), batched=True,
batch_size=self.processing_batch_size, num_proc=self.processing_num_proc, input_columns=[f'regret_{text_type}', f'recommendation_{text_type}'], remove_columns=dataset.column_names)
encoded_text_type = encoded_text_type.rename_columns(
{col: f'{text_type}_{col}' for col in encoded_text_type.column_names}) # e.g. input_ids -> title_input_ids so we have separate input_ids for each text_type
# NOTE(review): datasets.Dataset defines __len__, so a 0-row Dataset is falsy
# here; `is not None` would express the intent precisely.
if encoded_dataset:
encoded_dataset = datasets.concatenate_datasets(
[encoded_dataset, encoded_text_type], axis=1)
# NOTE(review): an `else:` line appears to be missing before this assignment;
# as written it would overwrite the concatenation above.
encoded_dataset = encoded_text_type
# copy scalar features and label from original dataset to the encoded dataset
for scalar_feat in self.scalar_features:
encoded_dataset = encoded_dataset.add_column(
name=scalar_feat, column=dataset[scalar_feat])
if self.label_col:
encoded_dataset = encoded_dataset.add_column(
name=self.label_col, column=dataset[self.label_col])
if self.keep_video_ids_for_predictions:
# NOTE(review): the loop variable `id` shadows the builtin of the same name.
for id in ['regret_id', "recommendation_id"]:
encoded_dataset = encoded_dataset.add_column(
name=id, column=dataset[id])
# NOTE(review): the call opening the next line (presumably
# encoded_dataset.set_format( ) was lost in extraction — confirm.
type='torch', columns=encoded_dataset.column_names)
return encoded_dataset
def _encode_streaming(self, dataset):
# Lazily encode a streaming IterableDataset. remove_columns drops every
# original column except the scalar features, the label and (optionally)
# the video ids. The result is returned with torch formatting.
# NOTE(review): the text between `=` and `, batched=True` (presumably
# dataset.map(self._encode_on_the_fly ) was lost in extraction — confirm.
encoded_dataset =, batched=True,
batch_size=self.processing_batch_size, remove_columns=list(set(self._stream_dataset_column_names)-set(self.scalar_features + (
[self.label_col] if self.label_col else []) + (['regret_id', "recommendation_id"] if self.keep_video_ids_for_predictions else [])))) # IterableDataset doesn't have column_names attribute as normal Dataset
encoded_dataset = encoded_dataset.with_format("torch")
return encoded_dataset
def _encode_on_the_fly(self, batch):
for text_type in self.text_types:
encoded_text_type = dict(self.tokenizer(
batch[f'regret_{text_type}'], batch[f'recommendation_{text_type}'], padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt"))
for encoded_key in encoded_text_type.copy():
encoded_text_type[f"{text_type}_{encoded_key}"] = encoded_text_type.pop(encoded_key) if not self.streaming_dataset else encoded_text_type.pop(
encoded_key).squeeze(0) # e.g. input_ids -> title_input_ids so we have separate input_ids for each text_type
del batch[f'regret_{text_type}']
del batch[f'recommendation_{text_type}']
for scalar_feat in self.scalar_features:
batch[scalar_feat] = torch.as_tensor(
batch[scalar_feat]) if not self.streaming_dataset else torch.as_tensor(batch[scalar_feat]).squeeze(0)
if self.label_col:
batch[self.label_col] = torch.as_tensor(
batch[self.label_col]) if not self.streaming_dataset else torch.as_tensor(batch[self.label_col]).squeeze(0)
return batch
# LightningModule that scores a regret/recommendation pair. One cross-encoder
# per text type produces logits; these are concatenated with the scalar
# features and fed to a single linear layer that emits one logit.
# NOTE(review): indentation and some lines were lost in extraction.
class RRUM(pl.LightningModule):
def __init__(self, text_types, scalar_features, label_col, cross_encoder_model_name_or_path, optimizer_config=None, freeze_policy=None, pos_weight=None):
# NOTE(review): no super().__init__() call is visible. nn.Module raises when a
# submodule (e.g. the ModuleDict below) is assigned before Module.__init__
# has run, so a call has presumably been lost here — confirm.
self.text_types = text_types
self.scalar_features = scalar_features
self.label_col = label_col
self.optimizer_config = optimizer_config
self.cross_encoder_model_name_or_path = cross_encoder_model_name_or_path
# One separately loaded cross-encoder per text type.
self.cross_encoders = nn.ModuleDict({})
for t in self.text_types:
self.cross_encoders[t] = AutoModelForSequenceClassification.from_pretrained(
# NOTE(review): the argument line of from_pretrained(...) is missing from the
# extracted source (presumably the model name/path) — confirm.
# freeze_policy(name) -> truthy freezes the parameter with that name.
if freeze_policy is not None:
for xe in self.cross_encoders.values():
for name, param in xe.named_parameters():
if freeze_policy(name):
param.requires_grad = False
# Probe the first cross-encoder with a dummy (1, 2) batch of token id 1
# (randint's upper bound is exclusive) to learn how many logits it emits.
cross_encoder_out_features = list(self.cross_encoders.values())[0](
torch.randint(1, 2, (1, 2))).logits.size(dim=1)
# Input width = logits of every cross-encoder + one column per scalar feature.
self.lin1 = nn.Linear(len(self.cross_encoders) * cross_encoder_out_features +
len(self.scalar_features), 1)
# NOTE(review): newer torchmetrics releases (0.11+) require a `task` argument
# (e.g. task="binary") for these constructors; confirm the pinned version.
self.ac_metric = torchmetrics.Accuracy()
self.pr_metric = torchmetrics.Precision()
self.re_metric = torchmetrics.Recall()
self.auc_metric = torchmetrics.AUROC()
# Optional positive-class weighting for the BCE-with-logits loss.
if pos_weight:
self.loss = nn.BCEWithLogitsLoss(
# NOTE(review): the argument line of the weighted loss and an `else:` line
# appear to be missing here.
self.loss = nn.BCEWithLogitsLoss()
def forward(self, x):
# x: mapping of batch tensors whose keys are prefixed by text type
# (e.g. title_input_ids) plus the scalar feature columns.
# Returns the output of self.lin1 (one logit per example).
cross_logits = {}
for f in self.text_types:
# NOTE(review): `f in key` is a substring test. key.startswith(f'{f}_')
# would avoid picking up unrelated keys that merely contain the text-type
# name; for such keys split(...)[1] could raise IndexError or mis-strip.
inputs = {key.split(f'{f}_')[1]: x[key]
for key in x if f in key} # e.g. title_input_ids -> input_ids since we have separate input_ids for each text_type
cross_logits[f] = self.cross_encoders[f](**inputs).logits
# NOTE(review): the call opening the expression below (presumably torch.cat(
# with a dim argument) was lost in extraction — confirm. Each scalar feature
# is turned into a column via [:, None] before concatenation.
x =[*cross_logits.values()] +
[x[scalar][:, None] for scalar in self.scalar_features],
del cross_logits
x = self.lin1(x)
return x
def configure_optimizers(self):
# A caller-supplied optimizer_config(module) takes precedence; its return
# value is handed to Lightning unchanged.
if self.optimizer_config:
return self.optimizer_config(self)
# Default: AdamW (lr=5e-5) with transformers' linear warmup-then-decay
# schedule, stepped after every optimizer step ('interval': 'step'). The
# visible fragment suggests warmup = 5% of trainer.estimated_stepping_batches.
optimizer = torch.optim.AdamW(self.parameters(), lr=5e-5)
# NOTE(review): the argument line after the opening parenthesis below was lost
# in extraction (presumably the optimizer and num_warmup_steps=int( ),
# possibly with num_training_steps) — confirm.
scheduler = get_linear_schedule_with_warmup(
self.trainer.estimated_stepping_batches * 0.05),
scheduler = {'scheduler': scheduler,
'interval': 'step', 'frequency': 1}
return [optimizer], [scheduler]
def training_step(self, train_batch, batch_idx):
    """Compute and log the BCE-with-logits training loss for one batch.

    The label column is reshaped to a float ``(batch, 1)`` column to match
    the model's single-logit output. This assumes labels arrive as a 1-D
    tensor.

    Returns:
        The scalar loss tensor that Lightning backpropagates.
    """
    targets = train_batch[self.label_col].unsqueeze(1).float()
    logits = self(train_batch)
    loss = self.loss(logits, targets)
    self.log('train_loss', loss)
    return loss
def validation_step(self, val_batch, batch_idx):
    """Compute the validation loss and update/log the classification metrics.

    The model emits one logit per example. Sigmoid probabilities and integer
    labels are fed to the accuracy, precision, recall and AUROC metric
    objects before they are logged, so Lightning can compute them over the
    epoch. This assumes labels arrive as a 1-D tensor.
    """
    labels = val_batch[self.label_col]
    logits = self(val_batch)
    loss = self.loss(logits, labels.unsqueeze(1).float())

    # torchmetrics objects only report meaningful values after update() has
    # seen predictions; logging a never-updated metric computes over no data.
    probs = torch.sigmoid(logits).squeeze(1)
    targets = labels.int()
    for metric in (self.ac_metric, self.pr_metric, self.re_metric, self.auc_metric):
        metric.update(probs, targets)

    self.log('validation_accuracy', self.ac_metric)
    self.log('validation_precision', self.pr_metric)
    self.log('validation_recall', self.re_metric)
    self.log('validation_auc', self.auc_metric)
    self.log('val_loss', loss, prog_bar=True)
def validation_epoch_end(self, outputs):
# Log the validation metric objects again under *_ep names at epoch end.
# NOTE(review): this hook was removed in PyTorch Lightning 2.0
# (on_validation_epoch_end replaces it). The metrics must already have been
# updated with predictions (e.g. in validation_step) for these values to be
# meaningful. This method may also continue past the end of the visible
# source.
self.log('validation_accuracy_ep', self.ac_metric)
self.log('validation_precision_ep', self.pr_metric)
self.log('validation_recall_ep', self.re_metric)
self.log('validation_auc_ep', self.auc_metric)