"""Finetuning example. | |
Trains the torchMoji model on the kaggle insults dataset, using the 'chain-thaw' | |
finetuning method and the accuracy metric. See the blog post at | |
https://medium.com/@bjarkefelbo/what-can-we-learn-from-emojis-6beb165a5ea0 | |
for more information. Note that results may differ a bit due to slight | |
changes in preprocessing and train/val/test split. | |
The 'chain-thaw' method does the following: | |
0) Load all weights except for the softmax layer. Extend the embedding layer if | |
necessary, initialising the new weights with random values. | |
1) Freeze every layer except the last (softmax) layer and train it. | |
2) Freeze every layer except the first layer and train it. | |
3) Freeze every layer except the second etc., until the second last layer. | |
4) Unfreeze all layers and train entire model. | |
""" | |
from __future__ import print_function
import example_helper
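# example_helper adjusts sys.path so the torchmoji package can be imported
# when running this script from the examples/ directory.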
import json

from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH
from torchmoji.finetuning import (
    load_benchmark,
    finetune)
DATASET_PATH = '../data/kaggle-insults/raw.pickle'
nb_classes = 2
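# The kaggle insults benchmark is binary: each comment is labelled as either
# an insult or not an insult, hence nb_classes = 2.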
with open('../model/vocabulary.json', 'r') as f:
    vocab = json.load(f)
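# vocab maps each known token to its index in the pretrained embedding layer.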
# Load dataset. Extend the existing vocabulary with up to 10000 tokens from
# the training dataset.
data = load_benchmark(DATASET_PATH, vocab, extend_with=10000)
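# load_benchmark returns a dict; the keys used below are 'texts', 'labels',
# 'batch_size' and 'added' (the number of tokens added to the vocabulary).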
# Set up model and finetune. Note that we have to extend the embedding layer
# with the number of tokens added to the vocabulary.
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH, extend_embedding=data['added'])
print(model)
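# finetune runs the chain-thaw procedure described in the docstring and
# returns the finetuned model along with its accuracy on the test split.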
model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
                      data['batch_size'], method='chain-thaw')
print('Acc: {}'.format(acc))