Spaces:
Running
Running
File size: 12,833 Bytes
3affa92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 |
# -*- coding: utf-8 -*-
""" Class average finetuning functions. Before using any of these finetuning
functions, ensure that the model is set up with nb_classes=2.
"""
from __future__ import print_function
import uuid
from time import sleep
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchmoji.global_variables import (
FINETUNING_METHODS,
WEIGHTS_DIR)
from torchmoji.finetuning import (
freeze_layers,
get_data_loader,
fit_model,
train_by_chain_thaw,
find_f1_threshold)
def relabel(y, current_label_nr, nb_classes):
    """ Turns a multi-class label matrix into a one-vs-rest binary vector
    for a single class.

    # Arguments:
        y: Outputs to be relabelled. Either a 1-D vector (returned as-is for
            binary datasets) or a 2-D array indexed as [sample, class].
        current_label_nr: Index of the class that becomes the positive label.
        nb_classes: Total number of classes.

    # Returns:
        `y` itself when the dataset is already binary (nb_classes == 2 and
        1-D labels); otherwise a new float array of length len(y) holding 1
        where column `current_label_nr` equals 1, and 0 elsewhere.
    """
    already_binary = nb_classes == 2 and len(y.shape) == 1
    if already_binary:
        return y

    positive_rows = np.where(y[:, current_label_nr] == 1)[0]
    binary = np.zeros(len(y))
    binary[positive_rows] = 1
    return binary
def class_avg_finetune(model, texts, labels, nb_classes, batch_size,
                       method, epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
                       verbose=True):
    """ Sets up the loss and optimizer for the chosen finetuning method,
    finetunes the model and scores it with the class average F1 metric.

    Intermediate weight files are written to WEIGHTS_DIR under random names
    and removed again (best effort) before this function returns.

    # Arguments:
        model: Model to be finetuned
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset.
        batch_size: Batch size.
        method: Finetuning method to be used. For available methods, see
            FINETUNING_METHODS in global_variables.py. Note that the model
            should be defined accordingly (see docstring for torchmoji_transfer())
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
        embed_l2: L2 regularization (Adam weight decay) for the embedding
            layer. Only used by the 'full' and 'new' methods.
        verbose: Verbosity flag.

    # Returns:
        Model after finetuning,
        score after finetuning using the class average F1 metric.

    # Raises:
        ValueError: if `method` is not in FINETUNING_METHODS, or if it is a
            method for which no learning rate is configured here.
    """
    import os

    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (class_avg_finetune): '
                         'Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))

    (X_train, y_train) = (texts[0], labels[0])
    (X_val, y_val) = (texts[1], labels[1])
    (X_test, y_test) = (texts[2], labels[2])

    # Random names so that separate runs sharing WEIGHTS_DIR do not overwrite
    # each other's intermediate weight files.
    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
        .format(WEIGHTS_DIR, str(uuid.uuid4()))
    f1_init_path = '{}/torchmoji-f1-init-{}.bin' \
        .format(WEIGHTS_DIR, str(uuid.uuid4()))

    if method in ['last', 'new']:
        lr = 0.001
    elif method in ['full', 'chain-thaw']:
        lr = 0.0001
    else:
        # Previously this fell through with `lr` unbound and failed later
        # with a NameError; fail early with a clear message instead.
        raise ValueError('ERROR (class_avg_finetune): '
                         'No learning rate configured for method '
                         '{}'.format(method))

    loss_op = nn.BCEWithLogitsLoss()

    # Define optimizer; for chain-thaw it is defined later (after freezing).
    if method == 'last':
        # Freeze everything except the output layer first, so the optimizer
        # only receives the parameters that remain trainable.
        model = freeze_layers(model, unfrozen_keyword='output_layer')
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ['full', 'new']:
        # Add L2 regularization on embeddings only.
        embed_ids = {id(p) for p in model.embed.parameters()}
        base_params = [p for p in model.parameters()
                       if id(p) not in embed_ids and p.requires_grad]
        embed_parameters = [p for p in model.parameters()
                            if id(p) in embed_ids and p.requires_grad]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
        ], lr=lr)
    else:
        adam = None  # chain-thaw: optimizers are created per training step

    # Training
    if verbose:
        print('Method: {}'.format(method))
        print('Classes: {}'.format(nb_classes))

    try:
        if method == 'chain-thaw':
            result = class_avg_chainthaw(model, nb_classes=nb_classes,
                                         loss_op=loss_op,
                                         train=(X_train, y_train),
                                         val=(X_val, y_val),
                                         test=(X_test, y_test),
                                         batch_size=batch_size,
                                         epoch_size=epoch_size,
                                         nb_epochs=nb_epochs,
                                         checkpoint_weight_path=checkpoint_path,
                                         f1_init_weight_path=f1_init_path,
                                         verbose=verbose)
        else:
            result = class_avg_tune_trainable(model, nb_classes=nb_classes,
                                              loss_op=loss_op,
                                              optim_op=adam,
                                              train=(X_train, y_train),
                                              val=(X_val, y_val),
                                              test=(X_test, y_test),
                                              epoch_size=epoch_size,
                                              nb_epochs=nb_epochs,
                                              batch_size=batch_size,
                                              init_weight_path=f1_init_path,
                                              checkpoint_weight_path=checkpoint_path,
                                              verbose=verbose)
    finally:
        # The paths above are random and never returned, so no caller can
        # reuse these files; delete them rather than leaving full weight
        # copies behind in WEIGHTS_DIR on every call.
        for path in (checkpoint_path, f1_init_path):
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup: the file may never have been written
                # (e.g. training failed early) or may already be gone.
                pass

    return model, result
def prepare_labels(y_train, y_val, y_test, iter_i, nb_classes):
    """ Converts the train, validation and test outputs into a binary
    one-vs-rest task for class `iter_i`, using relabel().

    # Returns:
        Tuple (y_train_new, y_val_new, y_test_new).
    """
    return tuple(relabel(split, iter_i, nb_classes)
                 for split in (y_train, y_val, y_test))
def prepare_generators(X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size):
    """ Builds the training data loader and a fixed validation sample.

    Validation does not iterate a loader. It reuses the first batch of a
    loader created with `epoch_size` in place of the batch size. Keeping the
    validation set fixed avoids fluctuations in validation scores.

    # Returns:
        (train_gen, X_val_resamp, y_val_resamp)
    """
    train_gen = get_data_loader(X_train, y_train_new, batch_size,
                                extended_batch_sampler=True)
    val_batches = iter(get_data_loader(X_val, y_val_new, epoch_size,
                                       extended_batch_sampler=True))
    X_val_resamp, y_val_resamp = next(val_batches)
    return train_gen, X_val_resamp, y_val_resamp
def class_avg_tune_trainable(model, nb_classes, loss_op, optim_op, train, val, test,
                             epoch_size, nb_epochs, batch_size,
                             init_weight_path, checkpoint_weight_path, patience=5,
                             verbose=True):
    """ Finetunes the given model using the F1 measure.

    For a dataset with more than two classes, the model is trained once per
    class on a one-vs-rest relabelling. Each run starts from the same
    initial model weights and optimizer state, and the result is the average
    of the per-class test F1 scores.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        loss_op: Loss function, passed through to fit_model.
        optim_op: Optimizer, passed through to fit_model. Its state is
            snapshotted on entry and restored before training each class.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        batch_size: Batch size.
        init_weight_path: Filepath where weights will be initially saved before
            training each class. This file will be rewritten by the function.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Passed through to fit_model (presumably early-stopping
            patience in epochs).
        verbose: Verbosity flag.

    # Returns:
        F1 score of the trained model, averaged over the trained classes.
    """
    import copy

    total_f1 = 0
    nb_iter = nb_classes if nb_classes > 2 else 1

    # Unpack args
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    # Save and reload the initial weights AND optimizer state for each class
    # to avoid learning across classes. Without the optimizer reset, the
    # optimizer's accumulated state (e.g. Adam moment estimates) from one
    # class would carry over into the next.
    torch.save(model.state_dict(), init_weight_path)
    init_optim_state = copy.deepcopy(optim_op.state_dict())

    for i in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(i+1, nb_iter))

        model.load_state_dict(torch.load(init_weight_path))
        # Deep-copy again so in-place optimizer updates never touch the snapshot.
        optim_op.load_state_dict(copy.deepcopy(init_optim_state))

        y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
                                                            y_test, i, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = \
            prepare_generators(X_train, y_train_new, X_val, y_val_new,
                               batch_size, epoch_size)

        if verbose:
            print("Training..")
        fit_model(model, loss_op, optim_op, train_gen, [(X_val_resamp, y_val_resamp)],
                  nb_epochs, checkpoint_weight_path, patience, verbose=0)

        # Reload the best weights found to avoid overfitting
        # Wait a bit to allow proper closing of weights file
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_weight_path))

        # Evaluate in inference mode (dropout etc. disabled, no autograd
        # graph), then restore whatever mode the model was in before.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                y_pred_val = model(X_val).cpu().numpy()
                y_pred_test = model(X_test).cpu().numpy()
        finally:
            model.train(was_training)

        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)
        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t:  {}'.format(best_t))
        total_f1 += f1_test

    return total_f1 / nb_iter
def class_avg_chainthaw(model, nb_classes, loss_op, train, val, test, batch_size,
                        epoch_size, nb_epochs, checkpoint_weight_path,
                        f1_init_weight_path, patience=5,
                        initial_lr=0.001, next_lr=0.0001, verbose=True):
    """ Finetunes given model using chain-thaw and evaluates using F1.
    For a dataset with multiple classes, the model is trained once for
    each class, relabeling those classes into a binary classification task.
    The result is an average of all F1 scores for each class.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        loss_op: Loss function to be used during training.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        f1_init_weight_path: Filepath where weights will be saved to and
            reloaded from before training each class. This ensures that
            each class is trained independently. This file will be rewritten.
        patience: Passed through to train_by_chain_thaw (presumably
            early-stopping patience in epochs).
        initial_lr: Initial learning rate. Will only be used for the first
            chain-thaw training step.
        next_lr: Learning rate for every subsequent step.
        verbose: Verbosity flag.

    # Returns:
        Averaged F1 score.
    """
    # Unpack args
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    total_f1 = 0
    nb_iter = nb_classes if nb_classes > 2 else 1

    # Reloaded before each class so every class starts from the same weights.
    torch.save(model.state_dict(), f1_init_weight_path)

    for i in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(i+1, nb_iter))

        model.load_state_dict(torch.load(f1_init_weight_path))
        y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
                                                            y_test, i, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = \
            prepare_generators(X_train, y_train_new, X_val, y_val_new,
                               batch_size, epoch_size)

        if verbose:
            print("Training..")

        # Train using chain-thaw. Unlike class_avg_tune_trainable, no explicit
        # reload of checkpoint_weight_path follows; presumably
        # train_by_chain_thaw leaves the best weights loaded (not visible here).
        train_by_chain_thaw(model=model, train_gen=train_gen,
                            val_gen=[(X_val_resamp, y_val_resamp)],
                            loss_op=loss_op, patience=patience,
                            nb_epochs=nb_epochs,
                            checkpoint_path=checkpoint_weight_path,
                            initial_lr=initial_lr, next_lr=next_lr,
                            verbose=verbose)

        # Evaluate in inference mode (dropout etc. disabled, no autograd
        # graph), then restore whatever mode the model was in before.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                y_pred_val = model(X_val).cpu().numpy()
                y_pred_test = model(X_test).cpu().numpy()
        finally:
            model.train(was_training)

        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)

        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t:  {}'.format(best_t))
        total_f1 += f1_test

    return total_f1 / nb_iter
|