# holiday_testing/test_models/setfit_model_finetune.py
# Last commit 0fdb130 by svystun-taras: "created the updated web ui"
from torch import nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
def get_eval_metric(y_pred, y_test):
    """Score predictions against ground truth with scikit-learn metrics.

    Note the argument order is (predictions, references); scikit-learn itself
    expects (y_true, y_pred), so the two are swapped when forwarded below.

    Returns:
        dict with 'accuracy'; support-weighted 'precision', 'recall' and 'f1';
        and 'confusion_mat', the confusion matrix normalised over the
        true-label rows.
    """
    weighted_scorers = (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )
    scores = {'accuracy': accuracy_score(y_test, y_pred)}
    for key, scorer in weighted_scorers:
        scores[key] = scorer(y_test, y_pred, average='weighted')
    scores['confusion_mat'] = confusion_matrix(y_test, y_pred, normalize='true')
    return scores
class MLP(nn.Module):
    """Two-layer MLP classification head for a SetFit model.

    Layers: Linear(input_size, hidden_size) -> BatchNorm1d -> ReLU -> Dropout
    -> Linear(hidden_size, output_size). ``forward`` returns raw logits;
    ``predict_proba`` returns softmax probabilities and ``predict`` the
    arg-max class index, both computed with the module in eval mode.

    Args:
        input_size: Size of the incoming embedding vectors.
        hidden_size: Width of the hidden layer.
        output_size: Number of output classes.
        dropout_rate: Dropout probability applied after the activation.
        class_weights: Optional per-class weight tensor for the loss returned
            by ``get_loss_fn``. Kept as a plain attribute (not a registered
            buffer), so ``module.to(device)`` does not move it; pass it
            already on the target device.
    """

    def __init__(self, input_size=768, hidden_size=256, output_size=3, dropout_rate=.2, class_weights=None):
        super().__init__()
        self.class_weights = class_weights
        self.activation = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map embeddings to class logits.

        Args:
            x: An embedding tensor (presumably (batch, input_size); BatchNorm1d
               needs the batch dimension, so confirm with callers), or a dict
               carrying that tensor under 'sentence_embedding' (assumed to be
               a sentence-transformers feature dict).

        Returns:
            The logits tensor, or ``{'logits': logits}`` when ``x`` was a dict.

        Raises:
            KeyError: if a dict input lacks 'sentence_embedding'.
        """
        input_is_dict = isinstance(x, dict)
        if input_is_dict:
            # Explicit check rather than `assert`, which is stripped under -O.
            if 'sentence_embedding' not in x:
                raise KeyError("input dict has no 'sentence_embedding' entry")
            x = x['sentence_embedding']
        hidden = self.dropout(self.activation(self.bn1(self.fc1(x))))
        logits = self.fc2(hidden)
        return {'logits': logits} if input_is_dict else logits

    def _eval_logits(self, x):
        """Run ``forward`` in eval mode and return the bare logits tensor.

        Eval mode makes BatchNorm use its running statistics and disables
        Dropout, so inference is deterministic and works for a batch of one.
        The module's previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            out = self.forward(x)
        finally:
            self.train(was_training)
        return out['logits'] if isinstance(out, dict) else out

    def predict(self, x):
        """Return the index of the highest-scoring class for each row of ``x``."""
        # Tensor method instead of torch.max: `torch` is not imported at
        # module level, only `nn`.
        return self._eval_logits(x).argmax(dim=1)

    def predict_proba(self, x):
        """Return softmax class probabilities for ``x`` (each row sums to 1)."""
        return self._eval_logits(x).softmax(dim=1)

    def get_loss_fn(self):
        """Return a mean-reduced cross-entropy loss weighted by ``class_weights``."""
        return nn.CrossEntropyLoss(weight=self.class_weights, reduction='mean')
if __name__ == '__main__':
    # Script-only dependencies: importing this module for its helpers does not
    # require setfit / datasets / sentence-transformers.
    import warnings
    from pprint import pprint
    from time import perf_counter

    import torch
    from datasets import load_dataset
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.losses import BatchSemiHardTripletLoss
    from setfit import SetFitModel, Trainer, TrainingArguments
    from sklearn.utils.class_weight import compute_class_weight

    # Blanket suppression of every warning, including library notices.
    warnings.filterwarnings("ignore")

    SEED = 1003200212 + 1
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Upper bound on the val/test rows used for evaluation.
    N_EVAL_SAMPLES = 128

    start = perf_counter()

    dataset = load_dataset("CabraVC/vector_dataset_stratified_ttv_split_2023-12-05_21-07")

    # sklearn 'balanced' weights, square-rooted to pull them toward 1 (a softer
    # re-weighting of rare classes); computed once and reused.
    class_weights_vect = compute_class_weight('balanced', classes=[0, 1, 2], y=dataset['train']['labels'])
    class_weights = torch.tensor(class_weights_vect, dtype=torch.float).to(DEVICE) ** .5

    model_body = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
    model_head = MLP(hidden_size=256, class_weights=class_weights)  # 128 82%acc
    model = SetFitModel(model_body=model_body,
                        model_head=model_head,
                        labels=dataset['train'].features['labels'].names).to(DEVICE)

    train_ds = dataset['train']
    # min(...) keeps .select() from raising IndexError on a smaller split.
    val_ds = dataset['val'].select(range(min(N_EVAL_SAMPLES, len(dataset['val']))))
    test_ds = dataset['test'].select(range(min(N_EVAL_SAMPLES, len(dataset['test']))))

    # Tuple-valued settings are presumably (embedding phase, classifier phase)
    # per SetFit's TrainingArguments; inline comments record tuning notes.
    train_args = TrainingArguments(
        seed=SEED,
        batch_size=(16, 24),
        num_epochs=(15, 16),  # 15 best
        margin=.5,  # .5, 1, .8 1.1 good, .5 best, .4 BEST
        loss=BatchSemiHardTripletLoss,
        use_amp=True,
        body_learning_rate=(3e-6, 4e-5),  # 5e-5 for smaller margin=.3, (2e-6, 2-3 e-5) best
        l2_weight=7e-3,
        evaluation_strategy='epoch',
        end_to_end=True,
        samples_per_label=4,
        max_length=model.model_body.get_max_seq_length(),
    )

    trainer = Trainer(
        model=model,
        args=train_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        metric=get_eval_metric,
        column_mapping={'texts': 'text', 'labels': 'label'},
    )

    # Baseline on the held-out subset before any fine-tuning.
    print('Test unseen data (before training)')
    metrics = trainer.evaluate(test_ds)
    pprint(metrics)

    trainer.train()

    print('Test on train data')
    metrics = trainer.evaluate(train_ds)
    pprint(metrics)

    print('Test unseen data (after training)')
    metrics = trainer.evaluate(test_ds)
    pprint(metrics)

    trainer.push_to_hub('CabraVC/emb_classifier_model',
                        private=True)

    print('-' * 50)
    print('Successfully trained the model.')
    # Read the clock once so minutes and seconds describe the same instant;
    # integer divmod avoids rounding up to "60 secs".
    mins, secs = divmod(int(perf_counter() - start), 60)
    print(f'It took me: {mins} mins {secs} secs')