# holiday_testing/test_models/train_classificator.py
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import numpy as np
# import torch.nn as nn
torch.set_printoptions(sci_mode=False)
# labels = ['buy', 'hold', 'sell']
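# MLP classification head trained on precomputed sentence embeddings: it maps a 768-dimensional
# embedding (the all-distilroberta-v1 output size) to 3 logits, one per recommendation class
# (buy / hold / sell, per the mapping noted above). It accepts either a plain tensor or the
# feature dict passed around inside a sentence-transformers pipeline.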
class MLP(nn.Module):
    def __init__(self, input_size=768, hidden_size=256, output_size=3, dropout_rate=.2, class_weights=None):
        super(MLP, self).__init__()
        self.class_weights = class_weights
        self.activation = nn.ReLU()
        # self.activation = nn.Tanh()
        # self.activation = nn.LeakyReLU()
        # self.activation = nn.Sigmoid()
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        # nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        # nn.init.kaiming_normal_(self.fc2.weight)
    def forward(self, x):
        # sentence-transformers modules pass features as a dict keyed by 'sentence_embedding'
        input_is_dict = False
        if isinstance(x, dict):
            assert "sentence_embedding" in x
            input_is_dict = True
            x = x['sentence_embedding']

        x = self.fc1(x)
        x = self.bn1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)

        if input_is_dict:
            return {'logits': x}
        return x
    def predict(self, x):
        # index of the largest logit per sample
        _, predicted = torch.max(self.forward(x), 1)
        return predicted

    def predict_proba(self, x):
        # note: returns raw logits; apply F.softmax(..., dim=1) if normalized probabilities are needed
        return self.forward(x)

    def get_loss_fn(self):
        return nn.CrossEntropyLoss(weight=self.class_weights, reduction='mean')
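# Minimal usage sketch for the head (hypothetical shapes and data):
#   head = MLP(input_size=768, hidden_size=256)
#   head.eval()
#   logits = head(torch.randn(4, 768))           # -> tensor of shape (4, 3)
#   probs = F.softmax(logits, dim=1)             # per-class probabilities
#   preds = head.predict(torch.randn(4, 768))    # -> class indices of shape (4,)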
def split_text(text, chunk_size=1200, chunk_overlap=200):
    # imported lazily so the module does not require langchain unless this helper is actually used
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap,
        length_function=len, separators=[" ", ",", "\n"]
    )
    text_chunks = text_splitter.create_documents([text])
    return text_chunks
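# Hypothetical usage of split_text (the rest of this script works on precomputed embeddings and
# never calls it):
#   chunks = split_text(report_text)                      # list of langchain Document objects
#   chunk_texts = [doc.page_content for doc in chunks]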
def plot_labels_distribution(dataset, save_as_filename=None):
    plt.figure(figsize=(10, 6))
    freqs, bins, _ = plt.hist([
        dataset['train']['labels'],
        dataset['val']['labels'],
        dataset['test']['labels']
    ], label=['80% - train', '10% - val', '10% - test'], bins=[-.25, .25, .75, 1.25, 1.75, 2.25])
    plt.legend(loc='upper left')
    plt.xticks([bin - .25 for bin in bins], ['', 'Buy', '', 'Hold', '', 'Sell'], fontsize=16)

    bin_centers = np.diff(bins) * .5 + bins[:-1]
    for offset, freq in zip([-.135, 0, .135], freqs):
        for fr, x in zip(freq, bin_centers):
            height = int(fr)
            if height:
                plt.annotate("{}".format(height),
                             xy=(x + offset, height),
                             xytext=(0, .2),
                             textcoords="offset points",
                             ha='center', va='bottom')

    plt.title('Labels distribution')
    if save_as_filename:
        plt.savefig(save_as_filename)
    plt.show()
def plot_training_metrics(losses, accuracies, show=False, save_as_filename=None):
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.plot(losses['train'], label='Training Loss')
    plt.plot(losses['val'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss over Epochs')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(accuracies['train'], label='Training Accuracy')
    plt.plot(accuracies['val'], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy over Epochs')
    plt.legend()

    plt.tight_layout()
    if save_as_filename:
        plt.savefig(save_as_filename)
    if show:
        plt.show()
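# train_model runs one training pass and one no-grad validation pass per epoch, tracks mean loss
# and accuracy for both splits, prints progress roughly eight times over the run, and steps the
# ReduceLROnPlateau scheduler on the epoch's mean validation loss.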
def train_model(model, criterion, optimizer, lr_scheduler, train_loader, val_loader, train_data, val_data, epochs):
    print_param = max(1, epochs // 8)  # report progress roughly eight times during training
    losses = {
        'train': [],
        'val': []
    }
    accuracies = {
        'train': [],
        'val': []
    }
    for epoch in range(epochs):
        # training pass
        model.train()
        total_loss = 0.0
        correct_predictions = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()

        losses['train'].append(total_loss / len(train_loader))
        accuracies['train'].append(correct_predictions / len(train_data))
        if epoch % print_param == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_loader)}, Accuracy: {correct_predictions / len(train_data)}")

        # validation pass
        model.eval()
        total_loss = 0.0
        correct_predictions = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct_predictions += (predicted == labels).sum().item()

        losses['val'].append(total_loss / len(val_loader))
        accuracies['val'].append(correct_predictions / len(val_data))
        if epoch % print_param == 0:
            print(f"Validation Loss: {total_loss / len(val_loader)}, Accuracy: {correct_predictions / len(val_data)}")

        # scheduler reacts to the epoch's mean validation loss
        lr_scheduler.step(total_loss / len(val_loader))
    return losses, accuracies
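# eval_model reports mean test loss and accuracy, then weighted precision/recall/F1 and a
# row-normalized confusion matrix via scikit-learn; the heatmap is only drawn when `show` or
# `save_as_filename` is set.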
def eval_model(model, criterion, test_loader, test_data, show=False, save_as_filename=None):
    total_loss = 0.0
    correct_predictions = 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        model.eval()
        for inputs, labels in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()

            probabilities = F.softmax(outputs, dim=1)
            predicted_labels = torch.argmax(probabilities, dim=1).tolist()
            all_labels.extend(labels.tolist())
            all_predictions.extend(predicted_labels)

    loss, accuracy = total_loss / len(test_loader), correct_predictions / len(test_data)
    print(f'Model test loss: {loss:.2f}, test accuracy: {accuracy * 100:.1f}%')

    accuracy = accuracy_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions, average='weighted')
    recall = recall_score(all_labels, all_predictions, average='weighted')
    f1 = f1_score(all_labels, all_predictions, average='weighted')
    confusion_mat = confusion_matrix(all_labels, all_predictions, normalize='true')
    print("Accuracy:", accuracy)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

    # class index -> name mapping, kept consistent with plot_labels_distribution (0: buy, 1: hold, 2: sell)
    labels = ['buy', 'hold', 'sell']
    if show or save_as_filename:
        plt.figure(figsize=(8, 6))
        sns.heatmap(confusion_mat, annot=True, fmt='.2%', cmap='Blues', xticklabels=labels, yticklabels=labels)
        plt.xlabel('Predicted labels')
        plt.ylabel('True labels')
        plt.title('Confusion Matrix')
        if save_as_filename:
            plt.savefig(save_as_filename)
        if show:
            plt.show()
    return loss, accuracy
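# Script entry point: load the stratified train/val/test dataset of precomputed report embeddings
# from the Hugging Face Hub, train the MLP head with class-weighted cross-entropy, plot the
# training curves, and evaluate on the held-out test split. The commented-out lines at the end
# show how the trained head could be saved or reloaded (torch .pth or safetensors).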
if __name__ == '__main__':
    from datasets import load_dataset
    from sentence_transformers import SentenceTransformer
    import sys
    from datetime import datetime
    from collections import Counter
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    import torch
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset
    from safetensors.torch import load_model, save_model
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
    from sklearn.utils.class_weight import compute_class_weight
    import warnings
    warnings.filterwarnings("ignore")

    # encoder that produced the sentence embeddings; the dataset below already stores them precomputed
    model_name = 'all-distilroberta-v1'
    # model_name = 'all-MiniLM-L6-v2'
    encoder = SentenceTransformer(model_name)

    dataset = load_dataset("CabraVC/vector_dataset_stratified_ttv_split_2023-12-05_21-07")
    # plot_labels_distribution(dataset
    #     # , save_as_filename=f'plots/labels_distribution_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    # )
    input_size = len(dataset['train']['embeddings'][0])
    hidden_size = 256
    dropout_rate = 0.2
    learning_rate = 2 * 1e-4
    batch_size = 256
    epochs = 100

    class_weights = torch.tensor(compute_class_weight('balanced', classes=[0, 1, 2], y=dataset['train']['labels']), dtype=torch.float) ** .5
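    # The line above tempers sklearn's 'balanced' weights with a square root so minority classes
    # are up-weighted less aggressively. Worked example with hypothetical counts (500 buy,
    # 300 hold, 200 sell out of 1000 samples): 'balanced' gives n_samples / (n_classes * n_c)
    # = 0.67, 1.11, 1.67, and the ** .5 brings them to roughly 0.82, 1.05, 1.29.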
    model = MLP(input_size=input_size, hidden_size=hidden_size, dropout_rate=dropout_rate, class_weights=class_weights)
    criterion = model.get_loss_fn()
    # print(class_weights)
    # criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=8 * 1e-2)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.25, patience=10, threshold=5 * 1e-5, min_lr=1e-7, verbose=True)

    train_data = TensorDataset(torch.tensor(dataset['train']['embeddings']), torch.tensor(dataset['train']['labels']))
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_data = TensorDataset(torch.tensor(dataset['val']['embeddings']), torch.tensor(dataset['val']['labels']))
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

    losses, accuracies = train_model(model, criterion, optimizer, lr_scheduler, train_loader, val_loader, train_data, val_data, epochs)
    plot_training_metrics(losses, accuracies
        # , save_as_filename=f'plots/training_metrics_plot_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    )
    test_data = TensorDataset(torch.tensor(dataset['test']['embeddings']), torch.tensor(dataset['test']['labels']))
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    loss, accuracy = eval_model(model, criterion, test_loader, test_data,
        # save_as_filename=f'plots/confusion_matrix_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.png'
    )

    # torch.save(model.state_dict(), f'models/head_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.pth')
    # save_model(model, f'models/head_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.safetensors')
    # load_model(model, f'models/head_{datetime.now().strftime("%Y-%m-%d_%H-%M")}.safetensors')
    # print(model)
    # dataset.push_to_hub(f'CabraVC/vector_dataset_stratified_ttv_split_{datetime.now().strftime("%Y-%m-%d_%H-%M")}', private=True)