import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np

torch.set_printoptions(sci_mode=False)


class MLP(nn.Module):
    """Linear classification head with dropout.

    Despite the name there is no hidden layer: the model is a single
    ``nn.Linear(input_size, output_size)`` whose output passes through
    ``nn.Dropout(dropout_rate)``.

    Args:
        input_size: Length of each input feature vector.
        output_size: Number of classes (logits produced per sample).
        dropout_rate: Drop probability passed to ``nn.Dropout``.
        class_weights: Optional per-class weight tensor, forwarded to
            ``nn.CrossEntropyLoss`` by :meth:`get_loss_fn`.
    """

    def __init__(self, input_size=768, output_size=3, dropout_rate=.2, class_weights=None):
        super().__init__()
        # Kept on the instance so get_loss_fn() can build a weighted loss.
        self.class_weights = class_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return class logits for ``x``.

        NOTE(review): dropout is applied to the *output logits*, not to the
        input features (a previous variant was
        ``self.linear(self.dropout(x))``). In train mode this zeroes whole
        class logits at random; confirm this placement is intended.
        """
        return self.dropout(self.linear(x))

    def predict(self, x):
        """Return the index of the highest-scoring class for each row of ``x``.

        The module is switched to eval mode (so dropout is disabled) and
        gradients are not tracked; the previous train/eval mode is restored
        afterwards, even if ``forward`` raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(x).argmax(dim=1)
        finally:
            self.train(was_training)

    def predict_proba(self, x):
        """Return the raw output of :meth:`forward` for ``x``.

        NOTE: despite the name, no softmax is applied - these are logits, and
        dropout is active if the module is in train mode. Apply
        ``torch.softmax(..., dim=1)`` when actual probabilities are needed.
        """
        return self.forward(x)

    def get_loss_fn(self):
        """Return a mean-reduced cross-entropy loss weighted by ``class_weights``."""
        return nn.CrossEntropyLoss(weight=self.class_weights, reduction='mean')


if __name__ == '__main__':
    from datasets import load_dataset
    from sentence_transformers import SentenceTransformer
    import sys
    # from datetime import datetime
    # from collections import Counter
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset
    from safetensors.torch import load_model, save_model
    from sklearn.utils.class_weight import compute_class_weight
    import warnings
    from train_classificator import (
        # MLP,
        plot_labels_distribution,
        plot_training_metrics,
        train_model,
        eval_model,
    )

    warnings.filterwarnings("ignore")

    SEED = 1003200212 + 1
    # SEED used to be defined but never applied. Seeding torch fixes the
    # weight init, the dropout masks and the DataLoader shuffling.
    torch.manual_seed(SEED)
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): DEVICE is not used below and the model stays on the CPU.
    # Check whether train_model/eval_model move tensors themselves before
    # adding model.to(DEVICE).

    dataset = load_dataset("CabraVC/vector_dataset_roberta-fine-tuned")
    # plot_labels_distribution(dataset
    #     # , save_as_filename=f'plots/labels_distribution_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    # )

    input_size = len(dataset['train']['embeddings'][0])
    num_classes = 3
    learning_rate = 5e-4
    weight_decay = 0
    batch_size = 128
    epochs = 40

    # Square root of sklearn's 'balanced' weights: pulls every weight toward
    # 1, so minority classes are up-weighted less aggressively.
    class_weights = torch.tensor(
        compute_class_weight('balanced', classes=np.arange(num_classes), y=dataset['train']['labels']),
        dtype=torch.float,
    ) ** .5

    model = MLP(input_size=input_size, output_size=num_classes, class_weights=class_weights)
    criterion = model.get_loss_fn()

    def _make_split(split, shuffle):
        """Wrap one dataset split as a (TensorDataset, DataLoader) pair."""
        data = TensorDataset(torch.tensor(dataset[split]['embeddings']),
                             torch.tensor(dataset[split]['labels']))
        return data, DataLoader(data, batch_size=batch_size, shuffle=shuffle)

    # Only the training loader is shuffled; evaluation loaders keep a fixed order.
    train_data, train_loader = _make_split('train', shuffle=True)
    val_data, val_loader = _make_split('val', shuffle=False)
    test_data, test_loader = _make_split('test', shuffle=False)

    # Metrics of the untrained head, as a baseline.
    baseline_loss, baseline_accuracy = eval_model(
        model, criterion, test_loader, test_data, show=False,
        # save_as_filename=f'plots/confusion_matrix_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    )

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=.2, patience=5, threshold=1e-4, min_lr=1e-7, verbose=True
    )

    losses, accuracies = train_model(model, criterion, optimizer, lr_scheduler,
                                     train_loader, val_loader, train_data, val_data, epochs)
    plot_training_metrics(losses, accuracies
        # , save_as_filename=f'plots/training_metrics_plot_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    )

    # Final evaluation reuses the test loader built above.
    loss, accuracy = eval_model(
        model, criterion, test_loader, test_data, show=False,
        # save_as_filename=f'plots/confusion_matrix_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    )

    # torch.save(model.state_dict(), f'models/linear_head.pth')
    # save_model(model, f'models/linear_head.safetensors')
    # load_model(model, f'models/linear_head.safetensors')
    # print(model)
    # dataset.push_to_hub(f'CabraVC/vector_dataset_stratified_ttv_split_{datetime.now().strftime("%Y-%m-%d_%H-%M")}', private=True)