"""Finetuning example.
Trains the torchMoji model on the Kaggle insults dataset, using the 'chain-thaw'
finetuning method and the accuracy metric. See the blog post at
https://medium.com/@bjarkefelbo/what-can-we-learn-from-emojis-6beb165a5ea0
for more information. Note that results may differ slightly due to small
changes in preprocessing and the train/val/test split.

The 'chain-thaw' method does the following:
0) Load all weights except for the softmax layer. Extend the embedding layer if
   necessary, initialising the new weights with random values.
1) Freeze every layer except the last (softmax) layer and train it.
2) Freeze every layer except the first layer and train it.
3) Repeat step 2 for the second layer, the third layer, and so on, up to the
   second-to-last layer.
4) Unfreeze all layers and train the entire model.
"""
from __future__ import print_function
import example_helper
import json
from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH
from torchmoji.finetuning import (
    load_benchmark,
    finetune)

DATASET_PATH = '../data/kaggle-insults/raw.pickle'
nb_classes = 2

with open('../model/vocabulary.json', 'r') as f:
    vocab = json.load(f)

# Load dataset. Extend the existing vocabulary with up to 10000 tokens from
# the training dataset.
data = load_benchmark(DATASET_PATH, vocab, extend_with=10000)
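
# Illustrative sketch only: one way that extending a vocabulary "with up to N
# tokens" could work. This is NOT the torchmoji implementation. The helper
# name, the whitespace tokenisation and the assumption that vocabulary indices
# are contiguous (so len(vocab) is the next free index) are all hypothetical.
from collections import Counter


def extend_vocab_sketch(vocab, train_texts, max_tokens=10000):
    """Add up to max_tokens of the most frequent unseen tokens to vocab.

    Returns the number of tokens that were actually added.
    """
    counts = Counter(tok for text in train_texts for tok in text.lower().split())
    added = 0
    # most_common() yields tokens in descending frequency order.
    for token, _ in counts.most_common():
        if added >= max_tokens:
            break
        if token not in vocab:
            vocab[token] = len(vocab)
            added += 1
    return added
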
# Set up model and finetune. Note that we have to extend the embedding layer
# with the number of tokens added to the vocabulary.
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH, extend_embedding=data['added'])
print(model)
model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
                      data['batch_size'], method='chain-thaw')
print('Acc: {}'.format(acc))
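

# The training order described in the module docstring (steps 1 to 4) can be
# sketched in plain Python. This is an illustrative sketch only, not the
# torchmoji implementation: `chain_thaw_schedule` and the layer names below are
# hypothetical. Each returned entry lists the layers left trainable (unfrozen)
# during one training stage; every other layer is assumed to be frozen.
def chain_thaw_schedule(layer_names):
    """Return, stage by stage, which layers are trainable under chain-thaw."""
    if not layer_names:
        return []
    # Step 1: train only the last (softmax) layer.
    stages = [[layer_names[-1]]]
    # Steps 2-3: train each remaining layer on its own, from the first layer
    # up to the second-to-last layer.
    stages.extend([name] for name in layer_names[:-1])
    # Step 4: unfreeze everything and train the whole model.
    stages.append(list(layer_names))
    return stages


# Hypothetical layer names, used only to show the resulting order.
example_schedule = chain_thaw_schedule(
    ['embedding', 'lstm_0', 'lstm_1', 'attention', 'softmax'])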