Spaces:
Paused
Paused
File size: 4,752 Bytes
0fdb130 501f2e5 0fdb130 501f2e5 0fdb130 501f2e5 0fdb130 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
# import torch.nn as nn
torch.set_printoptions(sci_mode=False)
class MLP(nn.Module):
    """Linear classification head followed by dropout.

    Despite the name, the module has no hidden layer: it is a single
    ``nn.Linear(input_size, output_size)`` whose output is passed through
    ``nn.Dropout``.

    Args:
        input_size: Size of the last dimension of the input tensor.
        output_size: Number of class scores produced per sample.
        dropout_rate: Probability handed to ``nn.Dropout``.
        class_weights: Optional per-class weights that ``get_loss_fn``
            forwards to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, input_size: int = 768, output_size: int = 3,
                 dropout_rate: float = .2,
                 class_weights: 'torch.Tensor | None' = None) -> None:
        super().__init__()
        # Stored as a plain attribute (not a registered buffer), so it is
        # neither saved in the state dict nor moved by ``.to(device)``.
        self.class_weights = class_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for ``x``.

        Dropout is only active while the module is in training mode.
        """
        # NOTE(review): dropout is applied to the logits, not to the input
        # features (``self.linear(self.dropout(x))``). Kept as-is because the
        # alternative was explicitly commented out by the author — confirm
        # this ordering is intended.
        return self.dropout(self.linear(x))

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return the index of the highest-scoring class for each sample.

        The reduction runs over the last dimension, so both a batch
        ``(N, input_size)`` and a single vector ``(input_size,)`` work.
        """
        return self.forward(x).argmax(dim=-1)

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities for each sample.

        The logits from ``forward`` are normalised with a softmax over the
        last dimension, so every row is non-negative and sums to 1.
        """
        return torch.softmax(self.forward(x), dim=-1)

    def get_loss_fn(self) -> nn.CrossEntropyLoss:
        """Build a mean-reduced cross-entropy loss using ``class_weights``."""
        return nn.CrossEntropyLoss(weight=self.class_weights, reduction='mean')
if __name__ == '__main__':
    # Script entry point: load precomputed embeddings, evaluate the untrained
    # linear head on the test split, train it, plot the training curves and
    # evaluate it on the test split again.
    from datasets import load_dataset
    from sentence_transformers import SentenceTransformer
    import sys
    # from datetime import datetime
    # from collections import Counter
    import torch
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset
    from safetensors.torch import load_model, save_model
    from sklearn.utils.class_weight import compute_class_weight
    import warnings
    from train_classificator import (
        # MLP,
        plot_labels_distribution,
        plot_training_metrics,
        train_model,
        eval_model
    )

    warnings.filterwarnings("ignore")

    SEED = 1003200212 + 1
    # SEED used to be defined but never applied; seeding torch makes the
    # weight initialisation and DataLoader shuffling repeatable across runs.
    torch.manual_seed(SEED)

    # NOTE(review): DEVICE is computed but nothing in this script moves the
    # model or the tensors to it — confirm whether train_model/eval_model
    # are expected to handle device placement.
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    dataset = load_dataset("CabraVC/vector_dataset_roberta-fine-tuned")
    # plot_labels_distribution(dataset
    #     # , save_as_filename=f'plots/labels_distribution_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    # )

    input_size = len(dataset['train']['embeddings'][0])
    learning_rate = 5e-4
    weight_decay = 0
    batch_size = 128
    epochs = 40

    def _split_to_loader(split):
        """Wrap one dataset split as a (TensorDataset, shuffling DataLoader) pair."""
        data = TensorDataset(torch.tensor(dataset[split]['embeddings']),
                             torch.tensor(dataset[split]['labels']))
        return data, DataLoader(data, batch_size=batch_size, shuffle=True)

    # 'balanced' class weights from scikit-learn, square-rooted before being
    # handed to the loss.
    class_weights = torch.tensor(
        compute_class_weight('balanced', classes=[0, 1, 2], y=dataset['train']['labels']),
        dtype=torch.float,
    ) ** .5

    model = MLP(input_size=input_size, class_weights=class_weights)
    criterion = model.get_loss_fn()

    # Evaluate on the test split before any training has happened.
    test_data, test_loader = _split_to_loader('test')
    loss, accuracy = eval_model(model, criterion, test_loader, test_data, show=False,
        # save_as_filename=f'plots/confusion_matrix_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    )

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases —
    # confirm against the PyTorch version this project pins.
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.2, patience=5, threshold=1e-4, min_lr=1e-7, verbose=True)

    train_data, train_loader = _split_to_loader('train')
    val_data, val_loader = _split_to_loader('val')

    losses, accuracies = train_model(model, criterion, optimizer, lr_scheduler, train_loader, val_loader, train_data, val_data, epochs)
    plot_training_metrics(losses, accuracies
        # , save_as_filename=f'plots/training_metrics_plot_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    )

    # Evaluate the trained model, reusing the test dataset and loader built
    # above instead of constructing identical ones a second time.
    loss, accuracy = eval_model(model, criterion, test_loader, test_data, show=False
        # save_as_filename=f'plots/confusion_matrix_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    )

    # torch.save(model.state_dict(), f'models/linear_head.pth')
    # save_model(model, f'models/linear_head.safetensors')
    # load_model(model, f'models/linear_head.safetensors')
    # print(model)
    # dataset.push_to_hub(f'CabraVC/vector_dataset_stratified_ttv_split_{datetime.now().strftime("%Y-%m-%d_%H-%M")}', private=True)
|